From 07cbf2545f705d0448631f479a18d0b86b7055dc Mon Sep 17 00:00:00 2001 From: immutablet Date: Wed, 12 Sep 2018 14:56:44 -0700 Subject: [PATCH] Lazily dial kms-plugin. --- .../server/options/encryptionconfig/config.go | 4 +- .../options/encryptionconfig/config_test.go | 3 +- .../value/encrypt/envelope/grpc_service.go | 79 +++--- .../envelope/grpc_service_unix_test.go | 234 +++++++++++++++--- 4 files changed, 249 insertions(+), 71 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config.go b/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config.go index 4c6fd5a39cf..feece9f029a 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config.go +++ b/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config.go @@ -24,6 +24,7 @@ import ( "io" "io/ioutil" "os" + "time" yaml "github.com/ghodss/yaml" @@ -40,6 +41,7 @@ const ( aesGCMTransformerPrefixV1 = "k8s:enc:aesgcm:v1:" secretboxTransformerPrefixV1 = "k8s:enc:secretbox:v1:" kmsTransformerPrefixV1 = "k8s:enc:kms:v1:" + kmsPluginConnectionTimeout = 3 * time.Second ) // GetTransformerOverrides returns the transformer overrides by reading and parsing the encryption provider configuration file @@ -160,7 +162,7 @@ func GetPrefixTransformers(config *ResourceConfig) ([]value.PrefixTransformer, e } // Get gRPC client service with endpoint. 
- envelopeService, err := envelopeServiceFactory(provider.KMS.Endpoint) + envelopeService, err := envelopeServiceFactory(provider.KMS.Endpoint, kmsPluginConnectionTimeout) if err != nil { return nil, fmt.Errorf("could not configure KMS plugin %q, error: %v", provider.KMS.Name, err) } diff --git a/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config_test.go b/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config_test.go index 957d7dedeb9..8cd6027e904 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config_test.go @@ -21,6 +21,7 @@ import ( "encoding/base64" "strings" "testing" + "time" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apiserver/pkg/storage/value" @@ -239,7 +240,7 @@ func (t *testEnvelopeService) Encrypt(data []byte) ([]byte, error) { } // The factory method to create mock envelope service. -func newMockEnvelopeService(endpoint string) (envelope.Service, error) { +func newMockEnvelopeService(endpoint string, timeout time.Duration) (envelope.Service, error) { return &testEnvelopeService{}, nil } diff --git a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/grpc_service.go b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/grpc_service.go index b29b621786a..a39ceeca0da 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/grpc_service.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/grpc_service.go @@ -23,6 +23,7 @@ import ( "net" "net/url" "strings" + "sync" "time" "github.com/golang/glog" @@ -39,19 +40,20 @@ const ( // Current version for the protocol interface definition. kmsapiVersion = "v1beta1" - // The timeout that communicate with KMS server. - timeout = 30 * time.Second + versionErrorf = "KMS provider api version %s is not supported, only %s is supported now" ) // The gRPC implementation for envelope.Service. 
type gRPCService struct { - // gRPC client instance - kmsClient kmsapi.KeyManagementServiceClient - connection *grpc.ClientConn + kmsClient kmsapi.KeyManagementServiceClient + connection *grpc.ClientConn + callTimeout time.Duration + mux sync.RWMutex + versionChecked bool } // NewGRPCService returns an envelope.Service which use gRPC to communicate the remote KMS provider. -func NewGRPCService(endpoint string) (Service, error) { +func NewGRPCService(endpoint string, callTimeout time.Duration) (Service, error) { glog.V(4).Infof("Configure KMS provider with endpoint: %s", endpoint) addr, err := parseEndpoint(endpoint) @@ -59,28 +61,28 @@ func NewGRPCService(endpoint string) (Service, error) { return nil, err } - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() + connection, err := grpc.Dial(addr, grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.FailFast(false)), grpc.WithDialer( + func(string, time.Duration) (net.Conn, error) { + // Ignoring addr and timeout arguments: + // addr - comes from the closure + // timeout - is ignored since we are connecting in a non-blocking configuration + c, err := net.DialTimeout(unixProtocol, addr, 0) + if err != nil { + glog.Errorf("failed to create connection to unix socket: %s, error: %v", addr, err) + } + return c, err + })) - connection, err := grpc.DialContext(ctx, addr, grpc.WithInsecure(), grpc.WithDialer(unixDial)) if err != nil { - return nil, fmt.Errorf("connect remote KMS provider %q failed, error: %v", addr, err) + return nil, fmt.Errorf("failed to create connection to %s, error: %v", endpoint, err) } kmsClient := kmsapi.NewKeyManagementServiceClient(connection) - - err = checkAPIVersion(kmsClient) - if err != nil { - connection.Close() - return nil, fmt.Errorf("failed check version for %q, error: %v", addr, err) - } - - return &gRPCService{kmsClient: kmsClient, connection: connection}, nil -} - -// This dialer explicitly ask gRPC to use unix socket as network. 
-func unixDial(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout(unixProtocol, addr, timeout) + return &gRPCService{ + kmsClient: kmsClient, + connection: connection, + callTimeout: callTimeout, + }, nil } // Parse the endpoint to extract schema, host or path. @@ -109,31 +111,37 @@ func parseEndpoint(endpoint string) (string, error) { return u.Path, nil } -// Check the KMS provider API version. -// Only matching kmsapiVersion is supported now. -func checkAPIVersion(kmsClient kmsapi.KeyManagementServiceClient) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() +func (g *gRPCService) checkAPIVersion(ctx context.Context) error { + g.mux.Lock() + defer g.mux.Unlock() + + if g.versionChecked { + return nil + } request := &kmsapi.VersionRequest{Version: kmsapiVersion} - response, err := kmsClient.Version(ctx, request) + response, err := g.kmsClient.Version(ctx, request) if err != nil { return fmt.Errorf("failed get version from remote KMS provider: %v", err) } if response.Version != kmsapiVersion { - return fmt.Errorf("KMS provider api version %s is not supported, only %s is supported now", - response.Version, kmsapiVersion) + return fmt.Errorf(versionErrorf, response.Version, kmsapiVersion) } + g.versionChecked = true - glog.V(4).Infof("KMS provider %s initialized, version: %s", response.RuntimeName, response.RuntimeVersion) + glog.V(4).Infof("Version of KMS provider is %s", response.Version) return nil } // Decrypt a given data string to obtain the original byte data. 
func (g *gRPCService) Decrypt(cipher []byte) ([]byte, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), g.callTimeout) defer cancel() + if err := g.checkAPIVersion(ctx); err != nil { + return nil, err + } + request := &kmsapi.DecryptRequest{Cipher: cipher, Version: kmsapiVersion} response, err := g.kmsClient.Decrypt(ctx, request) if err != nil { @@ -144,8 +152,11 @@ func (g *gRPCService) Decrypt(cipher []byte) ([]byte, error) { // Encrypt bytes to a string ciphertext. func (g *gRPCService) Encrypt(plain []byte) ([]byte, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), g.callTimeout) defer cancel() + if err := g.checkAPIVersion(ctx); err != nil { + return nil, err + } request := &kmsapi.EncryptRequest{Plain: plain, Version: kmsapiVersion} response, err := g.kmsClient.Encrypt(ctx, request) diff --git a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/grpc_service_unix_test.go b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/grpc_service_unix_test.go index 6e0aa12bad6..bc40b220c27 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/grpc_service_unix_test.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/grpc_service_unix_test.go @@ -25,7 +25,9 @@ import ( "fmt" "net" "reflect" + "sync" "testing" + "time" "google.golang.org/grpc" @@ -36,17 +38,143 @@ const ( endpoint = "unix:///@kms-socket.sock" ) -// Normal encryption and decryption operation. -func TestGRPCService(t *testing.T) { - // Start a test gRPC server. - server, err := startTestKMSProvider() +// TestKMSPluginLateStart tests the scenario where kms-plugin pod/container starts after kube-apiserver pod/container. 
+// Since the Dial to kms-plugin is non-blocking we expect the construction of gRPC service to succeed even when +// kms-plugin is not yet up - dialing happens in the background. +func TestKMSPluginLateStart(t *testing.T) { + callTimeout := 3 * time.Second + + service, err := NewGRPCService(endpoint, callTimeout) + if err != nil { + t.Fatalf("failed to create envelope service, error: %v", err) + } + defer destroyService(service) + + time.Sleep(callTimeout / 2) + f, err := startFakeKMSProvider(kmsapiVersion) if err != nil { t.Fatalf("failed to start test KMS provider server, error: %v", err) } - defer stopTestKMSProvider(server) + defer f.server.Stop() + + data := []byte("test data") + _, err = service.Encrypt(data) + if err != nil { + t.Fatalf("failed when execute encrypt, error: %v", err) + } +} + +// TestIntermittentConnectionLoss tests the scenario where the connection with kms-plugin is intermittently lost. +func TestIntermittentConnectionLoss(t *testing.T) { + var ( + wg1 sync.WaitGroup + wg2 sync.WaitGroup + timeout = 30 * time.Second + blackOut = 1 * time.Second + data = []byte("test data") + ) + // Start KMS Plugin + f, err := startFakeKMSProvider(kmsapiVersion) + if err != nil { + t.Fatalf("failed to start test KMS provider server, error: %v", err) + } + + // connect to kms plugin + service, err := NewGRPCService(endpoint, timeout) + if err != nil { + t.Fatalf("failed to create envelope service, error: %v", err) + } + defer destroyService(service) + + _, err = service.Encrypt(data) + if err != nil { + t.Fatalf("failed when execute encrypt, error: %v", err) + } + t.Log("Connected to KMSPlugin") + + // Stop KMS Plugin - simulating connection loss + f.server.Stop() + t.Log("KMS Plugin is stopped") + + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + // Call service to encrypt data. 
+ t.Log("Sending encrypt request") + wg1.Done() + _, err := service.Encrypt(data) + if err != nil { + t.Errorf("failed when executing encrypt, error: %v", err) + } + }() + + wg1.Wait() + time.Sleep(blackOut) + // Start KMS Plugin + f, err = startFakeKMSProvider(kmsapiVersion) + if err != nil { + t.Fatalf("failed to start test KMS provider server, error: %v", err) + } + defer f.server.Stop() + t.Log("Restarted KMS Plugin") + + wg2.Wait() +} + +func TestUnsupportedVersion(t *testing.T) { + ver := "invalid" + data := []byte("test data") + wantErr := fmt.Errorf(versionErrorf, ver, kmsapiVersion) + + f, err := startFakeKMSProvider(ver) + if err != nil { + t.Fatalf("failed to start test KMS provider server, error: %v", err) + } + defer f.server.Stop() + + s, err := NewGRPCService(endpoint, 1*time.Second) + if err != nil { + t.Fatal(err) + } + defer destroyService(s) + + // Encrypt + _, err = s.Encrypt(data) + if err == nil || err.Error() != wantErr.Error() { + t.Errorf("got err: %v, want: %v", err, wantErr) + } + + destroyService(s) + + s, err = NewGRPCService(endpoint, 1*time.Second) + if err != nil { + t.Fatal(err) + } + defer destroyService(s) + + // Decrypt + _, err = s.Decrypt(data) + if err == nil || err.Error() != wantErr.Error() { + t.Errorf("got err: %v, want: %v", err, wantErr) + } +} + +func TestConcurrentAccess(t *testing.T) { + +} + +// Normal encryption and decryption operation. +func TestGRPCService(t *testing.T) { + // Start a test gRPC server. + f, err := startFakeKMSProvider(kmsapiVersion) + if err != nil { + t.Fatalf("failed to start test KMS provider server, error: %v", err) + } + defer f.server.Stop() // Create the gRPC client service. 
- service, err := NewGRPCService(endpoint) + service, err := NewGRPCService(endpoint, 1*time.Second) if err != nil { t.Fatalf("failed to create envelope service, error: %v", err) } @@ -70,19 +198,65 @@ func TestGRPCService(t *testing.T) { } } +// Normal encryption and decryption operation by multiple go-routines. +func TestGRPCServiceConcurrentAccess(t *testing.T) { + // Start a test gRPC server. + f, err := startFakeKMSProvider(kmsapiVersion) + if err != nil { + t.Fatalf("failed to start test KMS provider server, error: %v", err) + } + defer f.server.Stop() + + // Create the gRPC client service. + service, err := NewGRPCService(endpoint, 1*time.Second) + if err != nil { + t.Fatalf("failed to create envelope service, error: %v", err) + } + defer destroyService(service) + + var wg sync.WaitGroup + n := 1000 + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + // Call service to encrypt data. + data := []byte("test data") + cipher, err := service.Encrypt(data) + if err != nil { + t.Errorf("failed when execute encrypt, error: %v", err) + } + + // Call service to decrypt data. + result, err := service.Decrypt(cipher) + if err != nil { + t.Errorf("failed when execute decrypt, error: %v", err) + } + + if !reflect.DeepEqual(data, result) { + t.Errorf("expect: %v, but: %v", data, result) + } + }() + } + + wg.Wait() +} + func destroyService(service Service) { - s := service.(*gRPCService) - s.connection.Close() + if service != nil { + s := service.(*gRPCService) + s.connection.Close() + } } // Test all those invalid configuration for KMS provider. func TestInvalidConfiguration(t *testing.T) { // Start a test gRPC server. 
- server, err := startTestKMSProvider() + f, err := startFakeKMSProvider(kmsapiVersion) if err != nil { t.Fatalf("failed to start test KMS provider server, error: %v", err) } - defer stopTestKMSProvider(server) + defer f.server.Stop() invalidConfigs := []struct { name string @@ -91,16 +265,12 @@ func TestInvalidConfiguration(t *testing.T) { }{ {"emptyConfiguration", kmsapiVersion, ""}, {"invalidScheme", kmsapiVersion, "tcp://localhost:6060"}, - {"unavailableEndpoint", kmsapiVersion, unixProtocol + ":///kms-socket.nonexist"}, - {"invalidAPIVersion", "invalidVersion", endpoint}, } for _, testCase := range invalidConfigs { t.Run(testCase.name, func(t *testing.T) { - setAPIVersion(testCase.apiVersion) - defer setAPIVersion(kmsapiVersion) - - _, err := NewGRPCService(testCase.endpoint) + f.apiVersion = testCase.apiVersion + _, err := NewGRPCService(testCase.endpoint, 1*time.Second) if err == nil { t.Fatalf("should fail to create envelope service for %s.", testCase.name) } @@ -109,7 +279,7 @@ func TestInvalidConfiguration(t *testing.T) { } // Start the gRPC server that listens on unix socket. -func startTestKMSProvider() (*grpc.Server, error) { +func startFakeKMSProvider(version string) (*fakeKMSPlugin, error) { sockFile, err := parseEndpoint(endpoint) if err != nil { return nil, fmt.Errorf("failed to parse endpoint:%q, error %v", endpoint, err) @@ -119,31 +289,25 @@ func startTestKMSProvider() (*grpc.Server, error) { return nil, fmt.Errorf("failed to listen on the unix socket, error: %v", err) } - server := grpc.NewServer() - kmsapi.RegisterKeyManagementServiceServer(server, &base64Server{}) - go server.Serve(listener) - return server, nil -} - -func stopTestKMSProvider(server *grpc.Server) { - server.Stop() + s := grpc.NewServer() + f := &fakeKMSPlugin{apiVersion: version, server: s} + kmsapi.RegisterKeyManagementServiceServer(s, f) + go s.Serve(listener) + return f, nil } // Fake gRPC sever for remote KMS provider. // Use base64 to simulate encrypt and decrypt. 
-type base64Server struct{} - -var testProviderAPIVersion = kmsapiVersion - -func setAPIVersion(apiVersion string) { - testProviderAPIVersion = apiVersion +type fakeKMSPlugin struct { + apiVersion string + server *grpc.Server } -func (s *base64Server) Version(ctx context.Context, request *kmsapi.VersionRequest) (*kmsapi.VersionResponse, error) { - return &kmsapi.VersionResponse{Version: testProviderAPIVersion, RuntimeName: "testKMS", RuntimeVersion: "0.0.1"}, nil +func (s *fakeKMSPlugin) Version(ctx context.Context, request *kmsapi.VersionRequest) (*kmsapi.VersionResponse, error) { + return &kmsapi.VersionResponse{Version: s.apiVersion, RuntimeName: "testKMS", RuntimeVersion: "0.0.1"}, nil } -func (s *base64Server) Decrypt(ctx context.Context, request *kmsapi.DecryptRequest) (*kmsapi.DecryptResponse, error) { +func (s *fakeKMSPlugin) Decrypt(ctx context.Context, request *kmsapi.DecryptRequest) (*kmsapi.DecryptResponse, error) { buf := make([]byte, base64.StdEncoding.DecodedLen(len(request.Cipher))) n, err := base64.StdEncoding.Decode(buf, request.Cipher) if err != nil { @@ -153,7 +317,7 @@ func (s *base64Server) Decrypt(ctx context.Context, request *kmsapi.DecryptReque return &kmsapi.DecryptResponse{Plain: buf[:n]}, nil } -func (s *base64Server) Encrypt(ctx context.Context, request *kmsapi.EncryptRequest) (*kmsapi.EncryptResponse, error) { +func (s *fakeKMSPlugin) Encrypt(ctx context.Context, request *kmsapi.EncryptRequest) (*kmsapi.EncryptResponse, error) { buf := make([]byte, base64.StdEncoding.EncodedLen(len(request.Plain))) base64.StdEncoding.Encode(buf, request.Plain) return &kmsapi.EncryptResponse{Cipher: buf}, nil