mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-21 10:51:29 +00:00
Merge pull request #116630 from aramase/aramase/c/rm_key_hierarchy
[KMSv2] remove key hierarchy in reference implementation
This commit is contained in:
commit
9bb778d48e
@ -10,19 +10,18 @@ require (
|
|||||||
k8s.io/apimachinery v0.0.0
|
k8s.io/apimachinery v0.0.0
|
||||||
k8s.io/client-go v0.0.0
|
k8s.io/client-go v0.0.0
|
||||||
k8s.io/klog/v2 v2.90.1
|
k8s.io/klog/v2 v2.90.1
|
||||||
k8s.io/utils v0.0.0-20230209194617-a36077c30491
|
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-logr/logr v1.2.3 // indirect
|
github.com/go-logr/logr v1.2.3 // indirect
|
||||||
github.com/golang/protobuf v1.5.3 // indirect
|
github.com/golang/protobuf v1.5.3 // indirect
|
||||||
github.com/google/uuid v1.3.0 // indirect
|
|
||||||
golang.org/x/net v0.8.0 // indirect
|
golang.org/x/net v0.8.0 // indirect
|
||||||
golang.org/x/sys v0.6.0 // indirect
|
golang.org/x/sys v0.6.0 // indirect
|
||||||
golang.org/x/text v0.8.0 // indirect
|
golang.org/x/text v0.8.0 // indirect
|
||||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
|
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
|
||||||
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect
|
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect
|
||||||
google.golang.org/protobuf v1.28.1 // indirect
|
google.golang.org/protobuf v1.28.1 // indirect
|
||||||
|
k8s.io/utils v0.0.0-20230209194617-a36077c30491 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace (
|
replace (
|
||||||
|
2
staging/src/k8s.io/kms/go.sum
generated
2
staging/src/k8s.io/kms/go.sum
generated
@ -50,8 +50,6 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
|
||||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
|
@ -11,7 +11,6 @@ require (
|
|||||||
github.com/go-logr/logr v1.2.3 // indirect
|
github.com/go-logr/logr v1.2.3 // indirect
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/golang/protobuf v1.5.3 // indirect
|
github.com/golang/protobuf v1.5.3 // indirect
|
||||||
github.com/google/uuid v1.3.0 // indirect
|
|
||||||
golang.org/x/net v0.8.0 // indirect
|
golang.org/x/net v0.8.0 // indirect
|
||||||
golang.org/x/sys v0.6.0 // indirect
|
golang.org/x/sys v0.6.0 // indirect
|
||||||
golang.org/x/text v0.8.0 // indirect
|
golang.org/x/text v0.8.0 // indirect
|
||||||
@ -19,7 +18,6 @@ require (
|
|||||||
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect
|
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect
|
||||||
google.golang.org/grpc v1.51.0 // indirect
|
google.golang.org/grpc v1.51.0 // indirect
|
||||||
google.golang.org/protobuf v1.28.1 // indirect
|
google.golang.org/protobuf v1.28.1 // indirect
|
||||||
k8s.io/apimachinery v0.0.0 // indirect
|
|
||||||
k8s.io/client-go v0.0.0 // indirect
|
k8s.io/client-go v0.0.0 // indirect
|
||||||
k8s.io/utils v0.0.0-20230209194617-a36077c30491 // indirect
|
k8s.io/utils v0.0.0-20230209194617-a36077c30491 // indirect
|
||||||
)
|
)
|
||||||
|
@ -50,8 +50,6 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
|
||||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
|
@ -26,7 +26,6 @@ import (
|
|||||||
|
|
||||||
"k8s.io/klog/v2"
|
"k8s.io/klog/v2"
|
||||||
"k8s.io/kms/internal"
|
"k8s.io/kms/internal"
|
||||||
"k8s.io/kms/pkg/hierarchy"
|
|
||||||
"k8s.io/kms/pkg/service"
|
"k8s.io/kms/pkg/service"
|
||||||
"k8s.io/kms/pkg/util"
|
"k8s.io/kms/pkg/util"
|
||||||
)
|
)
|
||||||
@ -55,14 +54,20 @@ func main() {
|
|||||||
grpcService := service.NewGRPCService(
|
grpcService := service.NewGRPCService(
|
||||||
addr,
|
addr,
|
||||||
*timeout,
|
*timeout,
|
||||||
hierarchy.NewLocalKEKService(ctx, remoteKMSService),
|
remoteKMSService,
|
||||||
)
|
)
|
||||||
|
|
||||||
klog.InfoS("starting server", "listenAddr", *listenAddr)
|
klog.InfoS("starting server", "listenAddr", *listenAddr)
|
||||||
if err := grpcService.ListenAndServe(); err != nil {
|
go func() {
|
||||||
klog.ErrorS(err, "failed to serve")
|
if err := grpcService.ListenAndServe(); err != nil {
|
||||||
os.Exit(1)
|
klog.ErrorS(err, "failed to serve")
|
||||||
}
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
klog.InfoS("shutting down server")
|
||||||
|
grpcService.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
// withShutdownSignal returns a copy of the parent context that will close if
|
// withShutdownSignal returns a copy of the parent context that will close if
|
||||||
|
@ -1,456 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright 2023 The Kubernetes Authors.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package hierarchy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto/aes"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"k8s.io/apimachinery/pkg/util/uuid"
|
|
||||||
"k8s.io/apimachinery/pkg/util/wait"
|
|
||||||
"k8s.io/klog/v2"
|
|
||||||
aestransformer "k8s.io/kms/pkg/encrypt/aes"
|
|
||||||
"k8s.io/kms/pkg/service"
|
|
||||||
"k8s.io/kms/pkg/value"
|
|
||||||
"k8s.io/utils/clock"
|
|
||||||
"k8s.io/utils/lru"
|
|
||||||
)
|
|
||||||
|
|
||||||
// localKEK is a struct that holds the local KEK and the remote KMS response.
|
|
||||||
type localKEK struct {
|
|
||||||
encKEK []byte
|
|
||||||
usage atomic.Uint64
|
|
||||||
expiry time.Time
|
|
||||||
transformer value.Transformer
|
|
||||||
remoteKMSResponse *service.EncryptResponse
|
|
||||||
generatedAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
// emptyContext is an empty slice of bytes. This is passed as value.Context to the
|
|
||||||
// GCM transformer. The grpc interface does not provide any additional authenticated data
|
|
||||||
// to use with AEAD.
|
|
||||||
emptyContext = value.DefaultContext([]byte{})
|
|
||||||
// errInvalidKMSAnnotationKeySuffix is returned when the annotation key suffix is not allowed.
|
|
||||||
errInvalidKMSAnnotationKeySuffix = fmt.Errorf("annotation keys are not allowed to use %s", referenceSuffix)
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
referenceSuffix = ".reference.encryption.k8s.io"
|
|
||||||
// referenceKEKAnnotationKey is the key used to store the localKEK in the annotations.
|
|
||||||
referenceKEKAnnotationKey = "encrypted-kek" + referenceSuffix
|
|
||||||
numAnnotations = 1
|
|
||||||
cacheSize = 1_000
|
|
||||||
|
|
||||||
// localKEKGenerationPollInterval is the interval at which the local KEK is checked for rotation.
|
|
||||||
localKEKGenerationPollInterval = 1 * time.Minute
|
|
||||||
|
|
||||||
// keyLength is the length of the local KEK in bytes.
|
|
||||||
// This is the same length used for the DEKs generated in kube-apiserver.
|
|
||||||
keyLength = 32
|
|
||||||
// keyMaxUsage is the maximum number of times an AES GCM key can be used
|
|
||||||
// with a random nonce: 2^32. The local KEK is a transformer that hold an
|
|
||||||
// AES GCM key. It is based on recommendations from
|
|
||||||
// https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf.
|
|
||||||
// It is reduced by one to be comparable with a atomic.Uint32.
|
|
||||||
// We picked a arbitrary number that is less than the max usage of the local KEK.
|
|
||||||
keyMaxUsage = 1<<22 - 1
|
|
||||||
// keySuggestedUsage is a threshold that triggers the rotation of a new local KEK. It means that half
|
|
||||||
// the number of times a local KEK can be used has been reached.
|
|
||||||
keySuggestedUsage = 1 << 21
|
|
||||||
// keyMaxAge is the maximum age of a local KEK. It is not a cryptographic necessity.
|
|
||||||
keyMaxAge = 7 * 24 * time.Hour
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ service.Service = &LocalKEKService{}
|
|
||||||
|
|
||||||
// LocalKEKService adds an additional KEK layer to reduce calls to the remote KMS.
|
|
||||||
// The local KEK is generated at startup in a controller loop and stored in the
|
|
||||||
// LocalKEKService. This KEK is used for all encryption operations. For the decrypt
|
|
||||||
// operation, if the encrypted local KEK is not found in the cache, the remote KMS
|
|
||||||
// is used to decrypt the local KEK.
|
|
||||||
type LocalKEKService struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
// remoteKMS is the remote kms that is used to encrypt and decrypt the local KEKs.
|
|
||||||
remoteKMS service.Service
|
|
||||||
// localKEKTracker is a atomic pointer to avoid locks. This is used to store the local KEK.
|
|
||||||
localKEKTracker atomic.Pointer[localKEK]
|
|
||||||
// transformers is a thread-safe LRU cache which caches decrypted DEKs indexed by their encrypted form.
|
|
||||||
// The cache is only used for the decrypt operation.
|
|
||||||
transformers *lru.Cache
|
|
||||||
// isReady is an atomic boolean that indicates if the localKEK service is ready for encryption.
|
|
||||||
isReady atomic.Bool
|
|
||||||
|
|
||||||
keyMaxUsage uint64
|
|
||||||
keySuggestedUsage uint64
|
|
||||||
keyMaxAge time.Duration
|
|
||||||
|
|
||||||
pollInterval time.Duration
|
|
||||||
|
|
||||||
clock clock.Clock
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewLocalKEKService is being initialized with a remote KMS service.
|
|
||||||
// The local KEK is generated in a controller loop. The local KEK is used for all
|
|
||||||
// encryption operations.
|
|
||||||
func NewLocalKEKService(ctx context.Context, remoteService service.Service) *LocalKEKService {
|
|
||||||
return newLocalKEKService(ctx, remoteService, keyMaxUsage, keySuggestedUsage, keyMaxAge, localKEKGenerationPollInterval, clock.RealClock{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func newLocalKEKService(ctx context.Context, remoteService service.Service, maxUsage, suggestedUsage uint64, maxAge, pollInterval time.Duration, clock clock.Clock) *LocalKEKService {
|
|
||||||
localKEKService := &LocalKEKService{
|
|
||||||
remoteKMS: remoteService,
|
|
||||||
transformers: lru.New(cacheSize),
|
|
||||||
keyMaxUsage: maxUsage,
|
|
||||||
keySuggestedUsage: suggestedUsage,
|
|
||||||
keyMaxAge: maxAge,
|
|
||||||
pollInterval: pollInterval,
|
|
||||||
clock: clock,
|
|
||||||
}
|
|
||||||
go localKEKService.run(ctx)
|
|
||||||
return localKEKService
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run method creates a new local KEK when the following thresholds are met:
|
|
||||||
// - the local KEK is used more often than keySuggestedUsage times or
|
|
||||||
// - the local KEK is older than a localExpiry.
|
|
||||||
//
|
|
||||||
// this method starts the controller and blocks until the context is cancelled.
|
|
||||||
func (m *LocalKEKService) run(ctx context.Context) {
|
|
||||||
// same as wait.UntilWithContext but with a custom clock
|
|
||||||
wait.BackoffUntil(func() {
|
|
||||||
lk := m.getLocalKEK()
|
|
||||||
// this is the first time the local KEK is generated
|
|
||||||
localKEKNotGenerated := lk.transformer == nil
|
|
||||||
// the local KEK is used more often than keySuggestedUsage times
|
|
||||||
localKEKUsageThresholdReached := lk.usage.Load() > m.keySuggestedUsage
|
|
||||||
// the local KEK is older than the expiry
|
|
||||||
localKEKExpired := m.clock.Now().After(lk.expiry)
|
|
||||||
|
|
||||||
if localKEKNotGenerated || localKEKUsageThresholdReached || localKEKExpired {
|
|
||||||
uid := string(uuid.NewUUID())
|
|
||||||
err := m.generateLocalKEK(ctx, uid, "")
|
|
||||||
if err == nil {
|
|
||||||
m.isReady.Store(true)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
klog.V(2).ErrorS(err, "failed to generate local KEK", "uid", uid)
|
|
||||||
// if the local KEK is expired and we cannot generate a new one, we set
|
|
||||||
// isReady to false because we can no longer encrypt new data.
|
|
||||||
if localKEKExpired {
|
|
||||||
m.isReady.Store(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, wait.NewJitteredBackoffManager(m.pollInterval, 0, m.clock), true, ctx.Done())
|
|
||||||
}
|
|
||||||
|
|
||||||
// getTransformerForEncryption returns the local KEK as localTransformer, the corresponding
|
|
||||||
// remoteKMSResponse and a potential error.
|
|
||||||
// On every use the localUsage is incremented by one.
|
|
||||||
// It is assumed that only one encryption will happen with the returned transformer.
|
|
||||||
func (m *LocalKEKService) getTransformerForEncryption(uid string) (value.Transformer, *service.EncryptResponse, error) {
|
|
||||||
lk := m.getLocalKEK()
|
|
||||||
// localKEK is not initialized yet
|
|
||||||
if lk.transformer == nil {
|
|
||||||
return nil, nil, fmt.Errorf("local KEK is not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.clock.Now().After(lk.expiry) {
|
|
||||||
return nil, nil, fmt.Errorf("local KEK has expired at %v", lk.expiry)
|
|
||||||
}
|
|
||||||
|
|
||||||
if counter := lk.usage.Add(1); counter >= m.keyMaxUsage {
|
|
||||||
return nil, nil, fmt.Errorf("local KEK has reached maximum usage of %d", keyMaxUsage)
|
|
||||||
}
|
|
||||||
|
|
||||||
return lk.transformer, lk.remoteKMSResponse, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt encrypts the plaintext with the localKEK.
|
|
||||||
func (m *LocalKEKService) Encrypt(ctx context.Context, uid string, pt []byte) (*service.EncryptResponse, error) {
|
|
||||||
transformer, resp, err := m.getTransformerForEncryption(uid)
|
|
||||||
if err != nil {
|
|
||||||
klog.V(2).ErrorS(err, "failed to get transformer for encryption", "uid", uid)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ct, err := transformer.TransformToStorage(ctx, pt, emptyContext)
|
|
||||||
if err != nil {
|
|
||||||
klog.V(2).ErrorS(err, "failed to encrypt data", "uid", uid)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &service.EncryptResponse{
|
|
||||||
Ciphertext: ct,
|
|
||||||
KeyID: resp.KeyID,
|
|
||||||
Annotations: resp.Annotations,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getTransformerForDecryption returns the transformer for the given encryptedKEK.
|
|
||||||
// - If the encryptedKEK is the current localKEK, the localKEK is returned.
|
|
||||||
// - If the encryptedKEK is not the current localKEK, the cache is checked.
|
|
||||||
// - If the encryptedKEK is not found in the cache, the remote KMS is used to decrypt the encryptedKEK.
|
|
||||||
func (m *LocalKEKService) getTransformerForDecryption(ctx context.Context, uid string, req *service.DecryptRequest) (value.Transformer, error) {
|
|
||||||
encKEK := req.Annotations[referenceKEKAnnotationKey]
|
|
||||||
|
|
||||||
// check if the key required for decryption is the current local KEK
|
|
||||||
// that's being used for encryption
|
|
||||||
lk := m.getLocalKEK()
|
|
||||||
if lk.transformer != nil && bytes.Equal(lk.encKEK, encKEK) {
|
|
||||||
return lk.transformer, nil
|
|
||||||
}
|
|
||||||
// check if the key required for decryption is already in the cache
|
|
||||||
if _transformer, found := m.transformers.Get(base64.StdEncoding.EncodeToString(encKEK)); found {
|
|
||||||
return _transformer.(value.Transformer), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := m.remoteKMS.Decrypt(ctx, uid, &service.DecryptRequest{
|
|
||||||
Ciphertext: encKEK,
|
|
||||||
KeyID: req.KeyID,
|
|
||||||
Annotations: annotationsWithoutReferenceKeys(req.Annotations),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
block, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
transformer := aestransformer.NewGCMTransformer(block)
|
|
||||||
|
|
||||||
// Overwrite the plain key with 0s.
|
|
||||||
copy(key, make([]byte, len(key)))
|
|
||||||
|
|
||||||
m.transformers.Add(base64.StdEncoding.EncodeToString(encKEK), transformer)
|
|
||||||
|
|
||||||
return transformer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decrypt attempts to decrypt the ciphertext with the localKEK, a KEK from the
|
|
||||||
// store, or the remote KMS.
|
|
||||||
func (m *LocalKEKService) Decrypt(ctx context.Context, uid string, req *service.DecryptRequest) ([]byte, error) {
|
|
||||||
if _, ok := req.Annotations[referenceKEKAnnotationKey]; !ok {
|
|
||||||
return nil, fmt.Errorf("unable to find local KEK for request with uid %q", uid)
|
|
||||||
}
|
|
||||||
|
|
||||||
transformer, err := m.getTransformerForDecryption(ctx, uid, req)
|
|
||||||
if err != nil {
|
|
||||||
klog.V(2).ErrorS(err, "failed to get transformer for decryption", "uid", uid)
|
|
||||||
return nil, fmt.Errorf("failed to get transformer for decryption: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pt, _, err := transformer.TransformFromStorage(ctx, req.Ciphertext, emptyContext)
|
|
||||||
if err != nil {
|
|
||||||
klog.V(2).ErrorS(err, "failed to decrypt data", "uid", uid)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return pt, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Status returns the status of the remote KMS.
|
|
||||||
func (m *LocalKEKService) Status(ctx context.Context) (*service.StatusResponse, error) {
|
|
||||||
resp, err := m.remoteKMS.Status(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := validateRemoteKMSStatusResponse(resp); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
r := copyStatusResponse(resp)
|
|
||||||
// if the remote KMS KeyID has changed, we need to rotate the local KEK
|
|
||||||
lk := m.getLocalKEK()
|
|
||||||
if lk.transformer != nil && r.KeyID != lk.remoteKMSResponse.KeyID {
|
|
||||||
if err := m.rotateLocalKEK(ctx, r.KeyID); err != nil {
|
|
||||||
klog.ErrorS(err, "failed to rotate local KEK", "expectedKeyID", r.KeyID, "currentKeyID", lk.remoteKMSResponse.KeyID)
|
|
||||||
// if rotation fails, we will overwrite the keyID to the one we are currently using
|
|
||||||
// for encryption as localKEKService is the authoritative source for the keyID.
|
|
||||||
r.KeyID = lk.remoteKMSResponse.KeyID
|
|
||||||
// TODO(aramase): we are currently not returning the error if rotation fails. We should
|
|
||||||
// allow the failed state for an arbitrary time period and return the error if the state
|
|
||||||
// is not eventually fixed.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var aggregateHealthz []string
|
|
||||||
if r.Healthz != "ok" {
|
|
||||||
aggregateHealthz = append(aggregateHealthz, r.Healthz)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !m.isReady.Load() {
|
|
||||||
// if the localKEKService is not ready, we will set the healthz status to not ready
|
|
||||||
klog.V(2).InfoS("localKEKService is not ready", "keyID", r.KeyID)
|
|
||||||
aggregateHealthz = append(aggregateHealthz, "localKEKService is not ready")
|
|
||||||
}
|
|
||||||
if len(aggregateHealthz) > 0 {
|
|
||||||
r.Healthz = strings.Join(aggregateHealthz, "; ")
|
|
||||||
}
|
|
||||||
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// rotateLocalKEK rotates the local KEK by generating a new local KEK and encrypting it with the
|
|
||||||
// remote KMS.
|
|
||||||
func (m *LocalKEKService) rotateLocalKEK(ctx context.Context, expectedKeyID string) error {
|
|
||||||
uid := string(uuid.NewUUID())
|
|
||||||
if err := m.generateLocalKEK(ctx, uid, expectedKeyID); err != nil {
|
|
||||||
klog.V(2).ErrorS(err, "failed to generate local KEK as part of rotation", "uid", uid)
|
|
||||||
return fmt.Errorf("[uid=%s] failed to generate local KEK as part of rotation: %w", uid, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateLocalKEK generates a new local KEK and encrypts it with the remote KMS.
|
|
||||||
// if expectedKeyID is not empty, it will check if the keyID returned from the remote KMS matches
|
|
||||||
// the expected keyID. If the keyID does not match, it will continue using the existing local KEK
|
|
||||||
// and return an error.
|
|
||||||
func (m *LocalKEKService) generateLocalKEK(ctx context.Context, uid, expectedKeyID string) error {
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
lk := m.getLocalKEK()
|
|
||||||
// if the localKEK was generated in the last pollInterval duration, we will not generate a new
|
|
||||||
// localKEK. This is to avoid regenerating a new localKEK for queued requests.
|
|
||||||
if lk.transformer != nil && m.clock.Since(lk.generatedAt) < m.pollInterval {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := generateKey(keyLength)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to generate local KEK: %w", err)
|
|
||||||
}
|
|
||||||
block, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create cipher block: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := m.remoteKMS.Encrypt(ctx, uid, key)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to encrypt local KEK: %w", err)
|
|
||||||
}
|
|
||||||
if err = validateRemoteKMSEncryptResponse(resp); err != nil {
|
|
||||||
return fmt.Errorf("invalid response from remote KMS: %w", err)
|
|
||||||
}
|
|
||||||
if expectedKeyID != "" && resp.KeyID != expectedKeyID {
|
|
||||||
return fmt.Errorf("keyID returned from remote KMS does not match expected keyID")
|
|
||||||
}
|
|
||||||
|
|
||||||
now := m.clock.Now()
|
|
||||||
m.localKEKTracker.Store(&localKEK{
|
|
||||||
encKEK: resp.Ciphertext,
|
|
||||||
expiry: now.Add(m.keyMaxAge),
|
|
||||||
usage: atomic.Uint64{},
|
|
||||||
transformer: aestransformer.NewGCMTransformer(block),
|
|
||||||
remoteKMSResponse: copyResponseAndAddLocalKEKAnnotation(resp),
|
|
||||||
generatedAt: now,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *LocalKEKService) getLocalKEK() *localKEK {
|
|
||||||
lk := m.localKEKTracker.Load()
|
|
||||||
if lk == nil {
|
|
||||||
return &localKEK{}
|
|
||||||
}
|
|
||||||
return lk
|
|
||||||
}
|
|
||||||
|
|
||||||
// copyResponseAndAddLocalKEKAnnotation returns a copy of the remoteKMSResponse with the
|
|
||||||
// referenceKEKAnnotationKey added to the annotations.
|
|
||||||
func copyResponseAndAddLocalKEKAnnotation(resp *service.EncryptResponse) *service.EncryptResponse {
|
|
||||||
annotations := make(map[string][]byte, len(resp.Annotations)+numAnnotations)
|
|
||||||
for s, bytes := range resp.Annotations {
|
|
||||||
s := s
|
|
||||||
bytes := bytes
|
|
||||||
annotations[s] = bytes
|
|
||||||
}
|
|
||||||
annotations[referenceKEKAnnotationKey] = resp.Ciphertext
|
|
||||||
|
|
||||||
return &service.EncryptResponse{
|
|
||||||
// Ciphertext is not set on purpose - it is different per Encrypt call
|
|
||||||
KeyID: resp.KeyID,
|
|
||||||
Annotations: annotations,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// copyStatusResponse returns a copy of the remote KMS status response.
|
|
||||||
func copyStatusResponse(resp *service.StatusResponse) *service.StatusResponse {
|
|
||||||
return &service.StatusResponse{
|
|
||||||
Healthz: resp.Healthz,
|
|
||||||
Version: resp.Version,
|
|
||||||
KeyID: resp.KeyID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// annotationsWithoutReferenceKeys returns a copy of the annotations without the reference implementation
|
|
||||||
// annotations.
|
|
||||||
func annotationsWithoutReferenceKeys(annotations map[string][]byte) map[string][]byte {
|
|
||||||
if len(annotations) <= numAnnotations {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
m := make(map[string][]byte, len(annotations)-numAnnotations)
|
|
||||||
for k, v := range annotations {
|
|
||||||
k, v := k, v
|
|
||||||
if strings.HasSuffix(k, referenceSuffix) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
m[k] = v
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateRemoteKMSEncryptResponse validates the EncryptResponse from the remote KMS.
|
|
||||||
func validateRemoteKMSEncryptResponse(resp *service.EncryptResponse) error {
|
|
||||||
// validate annotations don't contain the reference implementation annotations
|
|
||||||
for k := range resp.Annotations {
|
|
||||||
if strings.HasSuffix(k, referenceSuffix) {
|
|
||||||
return errInvalidKMSAnnotationKeySuffix
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateRemoteKMSStatusResponse validates the StatusResponse from the remote KMS.
|
|
||||||
func validateRemoteKMSStatusResponse(resp *service.StatusResponse) error {
|
|
||||||
if len(resp.KeyID) == 0 {
|
|
||||||
return fmt.Errorf("keyID is empty")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateKey generates a random key using system randomness.
|
|
||||||
func generateKey(length int) (key []byte, err error) {
|
|
||||||
key = make([]byte, length)
|
|
||||||
if _, err = rand.Read(key); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
@ -1,833 +0,0 @@
|
|||||||
/*
|
|
||||||
Copyright 2023 The Kubernetes Authors.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package hierarchy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"k8s.io/apimachinery/pkg/util/rand"
|
|
||||||
"k8s.io/apimachinery/pkg/util/wait"
|
|
||||||
"k8s.io/kms/pkg/service"
|
|
||||||
testingclock "k8s.io/utils/clock/testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
testAnnotationKey = "version.encryption.remote.io"
|
|
||||||
testAnnotationKeyVersion = "key-version.encryption.remote.io"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCopyResponseAndAddLocalKEKAnnotation(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
input *service.EncryptResponse
|
|
||||||
want *service.EncryptResponse
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "annotations is nil",
|
|
||||||
input: &service.EncryptResponse{
|
|
||||||
Ciphertext: []byte("encryptedLocalKEK"),
|
|
||||||
KeyID: "keyID",
|
|
||||||
Annotations: nil,
|
|
||||||
},
|
|
||||||
want: &service.EncryptResponse{
|
|
||||||
KeyID: "keyID",
|
|
||||||
Annotations: map[string][]byte{
|
|
||||||
referenceKEKAnnotationKey: []byte("encryptedLocalKEK"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "remote KMS sent 1 annotation",
|
|
||||||
input: &service.EncryptResponse{
|
|
||||||
Ciphertext: []byte("encryptedLocalKEK"),
|
|
||||||
KeyID: "keyID",
|
|
||||||
Annotations: map[string][]byte{
|
|
||||||
testAnnotationKey: []byte("1"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: &service.EncryptResponse{
|
|
||||||
KeyID: "keyID",
|
|
||||||
Annotations: map[string][]byte{
|
|
||||||
testAnnotationKey: []byte("1"),
|
|
||||||
referenceKEKAnnotationKey: []byte("encryptedLocalKEK"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "remote KMS sent 2 annotations",
|
|
||||||
input: &service.EncryptResponse{
|
|
||||||
Ciphertext: []byte("encryptedLocalKEK"),
|
|
||||||
KeyID: "keyID",
|
|
||||||
Annotations: map[string][]byte{
|
|
||||||
testAnnotationKey: []byte("1"),
|
|
||||||
testAnnotationKeyVersion: []byte("2"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: &service.EncryptResponse{
|
|
||||||
KeyID: "keyID",
|
|
||||||
Annotations: map[string][]byte{
|
|
||||||
testAnnotationKey: []byte("1"),
|
|
||||||
testAnnotationKeyVersion: []byte("2"),
|
|
||||||
referenceKEKAnnotationKey: []byte("encryptedLocalKEK"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
tc := tc
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
got := copyResponseAndAddLocalKEKAnnotation(tc.input)
|
|
||||||
if !reflect.DeepEqual(got, tc.want) {
|
|
||||||
t.Errorf("copyResponseAndAddLocalKEKAnnotation(%v) = %v, want %v", tc.input, got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAnnotationsWithoutReferenceKeys(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
input map[string][]byte
|
|
||||||
want map[string][]byte
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "annotations is nil",
|
|
||||||
input: nil,
|
|
||||||
want: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "annotations is empty",
|
|
||||||
input: map[string][]byte{},
|
|
||||||
want: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "annotations only contains reference keys",
|
|
||||||
input: map[string][]byte{
|
|
||||||
referenceKEKAnnotationKey: []byte("encryptedLocalKEK"),
|
|
||||||
},
|
|
||||||
want: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "annotations contains 1 reference key and 1 other key",
|
|
||||||
input: map[string][]byte{
|
|
||||||
referenceKEKAnnotationKey: []byte("encryptedLocalKEK"),
|
|
||||||
testAnnotationKey: []byte("1"),
|
|
||||||
},
|
|
||||||
want: map[string][]byte{
|
|
||||||
testAnnotationKey: []byte("1"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
tc := tc
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
got := annotationsWithoutReferenceKeys(tc.input)
|
|
||||||
if !reflect.DeepEqual(got, tc.want) {
|
|
||||||
t.Errorf("annotationsWithoutReferenceKeys(%v) = %v, want %v", tc.input, got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateRemoteKMSEncryptResponse(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
input *service.EncryptResponse
|
|
||||||
want error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "annotations is nil",
|
|
||||||
input: &service.EncryptResponse{},
|
|
||||||
want: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "annotation key contains reference suffix",
|
|
||||||
input: &service.EncryptResponse{
|
|
||||||
Annotations: map[string][]byte{
|
|
||||||
"version.reference.encryption.k8s.io": []byte("1"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: errInvalidKMSAnnotationKeySuffix,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no annotation key contains reference suffix",
|
|
||||||
input: &service.EncryptResponse{
|
|
||||||
Annotations: map[string][]byte{
|
|
||||||
testAnnotationKey: []byte("1"),
|
|
||||||
testAnnotationKeyVersion: []byte("2"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
want: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
tc := tc
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
got := validateRemoteKMSEncryptResponse(tc.input)
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("validateRemoteKMSResponse(%v) = %v, want %v", tc.input, got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateRemoteKMSStatusResponse(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
input *service.StatusResponse
|
|
||||||
wantErr string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "keyID is empty",
|
|
||||||
input: &service.StatusResponse{
|
|
||||||
KeyID: "",
|
|
||||||
},
|
|
||||||
wantErr: "keyID is empty",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no error",
|
|
||||||
input: &service.StatusResponse{
|
|
||||||
KeyID: "keyID",
|
|
||||||
},
|
|
||||||
wantErr: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
tc := tc
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
got := validateRemoteKMSStatusResponse(tc.input)
|
|
||||||
if tc.wantErr != "" {
|
|
||||||
if got == nil {
|
|
||||||
t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr)
|
|
||||||
}
|
|
||||||
if !strings.Contains(got.Error(), tc.wantErr) {
|
|
||||||
t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if got != nil {
|
|
||||||
t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ service.Service = &testRemoteService{}
|
|
||||||
|
|
||||||
type testRemoteService struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
|
|
||||||
keyID string
|
|
||||||
disabled bool
|
|
||||||
encryptCallCount int
|
|
||||||
decryptCallCount int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testRemoteService) Encrypt(ctx context.Context, uid string, plaintext []byte) (*service.EncryptResponse, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
s.encryptCallCount++
|
|
||||||
if s.disabled {
|
|
||||||
return nil, errors.New("failed to encrypt")
|
|
||||||
}
|
|
||||||
return &service.EncryptResponse{
|
|
||||||
KeyID: s.keyID,
|
|
||||||
Ciphertext: []byte(base64.StdEncoding.EncodeToString(plaintext)),
|
|
||||||
Annotations: map[string][]byte{
|
|
||||||
testAnnotationKey: []byte("1"),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testRemoteService) Decrypt(ctx context.Context, uid string, req *service.DecryptRequest) ([]byte, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
s.decryptCallCount++
|
|
||||||
if s.disabled {
|
|
||||||
return nil, errors.New("failed to decrypt")
|
|
||||||
}
|
|
||||||
if len(req.Annotations) != 1 {
|
|
||||||
return nil, errors.New("invalid annotations")
|
|
||||||
}
|
|
||||||
if v, ok := req.Annotations[testAnnotationKey]; !ok || string(v) != "1" {
|
|
||||||
return nil, errors.New("invalid version in annotations")
|
|
||||||
}
|
|
||||||
return base64.StdEncoding.DecodeString(string(req.Ciphertext))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testRemoteService) Status(ctx context.Context) (*service.StatusResponse, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
resp := &service.StatusResponse{
|
|
||||||
Version: "v2alpha1",
|
|
||||||
Healthz: "ok",
|
|
||||||
KeyID: s.keyID,
|
|
||||||
}
|
|
||||||
if s.disabled {
|
|
||||||
resp.Healthz = "remote KMS is disabled"
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testRemoteService) SetDisabledStatus(disabled bool) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.disabled = disabled
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testRemoteService) SetKeyID(keyID string) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.keyID = keyID
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testRemoteService) EncryptCallCount() int {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
return s.encryptCallCount
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testRemoteService) DecryptCallCount() int {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
return s.decryptCallCount
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEncrypt(t *testing.T) {
|
|
||||||
ctx := testContext(t)
|
|
||||||
remoteKMS := &testRemoteService{keyID: "test-key-id"}
|
|
||||||
localKEKService := NewLocalKEKService(ctx, remoteKMS)
|
|
||||||
|
|
||||||
waitUntilReady(t, localKEKService)
|
|
||||||
|
|
||||||
// local KEK is generated and encryption is successful
|
|
||||||
got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Encrypt() error = %v", err)
|
|
||||||
}
|
|
||||||
validateEncryptResponse(t, got, remoteKMS.keyID, localKEKService)
|
|
||||||
|
|
||||||
// local KEK is used for encryption even when remote KMS is failing
|
|
||||||
remoteKMS.SetDisabledStatus(true)
|
|
||||||
if got, err = localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")); err != nil {
|
|
||||||
t.Fatalf("Encrypt() error = %v", err)
|
|
||||||
}
|
|
||||||
validateEncryptResponse(t, got, remoteKMS.keyID, localKEKService)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEncryptError(t *testing.T) {
|
|
||||||
ctx := testContext(t)
|
|
||||||
remoteKMS := &testRemoteService{keyID: "test-key-id"}
|
|
||||||
localKEKService := NewLocalKEKService(ctx, remoteKMS)
|
|
||||||
|
|
||||||
// first time local KEK generation fails because of remote KMS
|
|
||||||
remoteKMS.SetDisabledStatus(true)
|
|
||||||
_, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("Encrypt() error = %v, want non-nil", err)
|
|
||||||
}
|
|
||||||
lk := localKEKService.getLocalKEK()
|
|
||||||
if lk.transformer != nil {
|
|
||||||
t.Fatalf("Encrypt() localKEKTracker = %v, want non-nil localKEK", lk)
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteKMS.SetDisabledStatus(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDecrypt(t *testing.T) {
|
|
||||||
ctx := testContext(t)
|
|
||||||
remoteKMS := &testRemoteService{keyID: "test-key-id"}
|
|
||||||
localKEKService := NewLocalKEKService(ctx, remoteKMS)
|
|
||||||
|
|
||||||
waitUntilReady(t, localKEKService)
|
|
||||||
|
|
||||||
// local KEK is generated and encryption/decryption is successful
|
|
||||||
got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Encrypt() error = %v", err)
|
|
||||||
}
|
|
||||||
if string(got.Ciphertext) == "test-plaintext" {
|
|
||||||
t.Fatalf("Encrypt() ciphertext = %v, want it to be encrypted", got.Ciphertext)
|
|
||||||
}
|
|
||||||
decryptRequest := &service.DecryptRequest{
|
|
||||||
Ciphertext: got.Ciphertext,
|
|
||||||
Annotations: got.Annotations,
|
|
||||||
KeyID: got.KeyID,
|
|
||||||
}
|
|
||||||
plaintext, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Decrypt() error = %v", err)
|
|
||||||
}
|
|
||||||
if string(plaintext) != "test-plaintext" {
|
|
||||||
t.Fatalf("Decrypt() plaintext = %v, want %v", string(plaintext), "test-plaintext")
|
|
||||||
}
|
|
||||||
|
|
||||||
// local KEK is used for decryption even when remote KMS is failing
|
|
||||||
remoteKMS.SetDisabledStatus(true)
|
|
||||||
if _, err = localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
|
|
||||||
t.Fatalf("Decrypt() error = %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDecryptError(t *testing.T) {
|
|
||||||
ctx := testContext(t)
|
|
||||||
remoteKMS := &testRemoteService{keyID: "test-key-id"}
|
|
||||||
localKEKService := NewLocalKEKService(ctx, remoteKMS)
|
|
||||||
|
|
||||||
waitUntilReady(t, localKEKService)
|
|
||||||
|
|
||||||
got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Encrypt() error = %v", err)
|
|
||||||
}
|
|
||||||
decryptRequest := &service.DecryptRequest{
|
|
||||||
Ciphertext: got.Ciphertext,
|
|
||||||
Annotations: got.Annotations,
|
|
||||||
KeyID: got.KeyID,
|
|
||||||
}
|
|
||||||
// local KEK for decryption not in cache and remote KMS is failing
|
|
||||||
remoteKMS.SetDisabledStatus(true)
|
|
||||||
lk := localKEKService.localKEKTracker.Load()
|
|
||||||
lk.transformer = nil
|
|
||||||
localKEKService.localKEKTracker.Store(lk)
|
|
||||||
|
|
||||||
// clear the cache
|
|
||||||
localKEKService.transformers.Clear()
|
|
||||||
if _, err = localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err == nil {
|
|
||||||
t.Fatalf("Decrypt() error = %v, want non-nil", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStatus(t *testing.T) {
|
|
||||||
ctx := testContext(t)
|
|
||||||
fakeClock := testingclock.NewFakeClock(time.Now())
|
|
||||||
remoteKMS := &testRemoteService{keyID: "test-key-id"}
|
|
||||||
localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Second, 100*time.Millisecond, fakeClock)
|
|
||||||
|
|
||||||
waitUntilReady(t, localKEKService)
|
|
||||||
|
|
||||||
got, err := localKEKService.Status(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Status() error = %v", err)
|
|
||||||
}
|
|
||||||
validateStatusResponse(t, got, "v2alpha1", "ok", "test-key-id")
|
|
||||||
|
|
||||||
fakeClock.Step(1 * time.Second)
|
|
||||||
// remote KMS is failing
|
|
||||||
remoteKMS.SetDisabledStatus(true)
|
|
||||||
// remote KMS keyID changed but local KEK not rotated because of remote KMS failure
|
|
||||||
// the keyID in status should be the old keyID
|
|
||||||
// the error should still be nil
|
|
||||||
remoteKMS.SetKeyID("test-key-id-2")
|
|
||||||
|
|
||||||
if got, err = localKEKService.Status(ctx); err != nil {
|
|
||||||
t.Fatalf("Status() error = %v, want nil", err)
|
|
||||||
}
|
|
||||||
validateStatusResponse(t, got, "v2alpha1", "remote KMS is disabled", "test-key-id")
|
|
||||||
|
|
||||||
fakeClock.Step(1 * time.Second)
|
|
||||||
// wait for local KEK to expire and local KEK service ready to be false
|
|
||||||
wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
|
|
||||||
return !localKEKService.isReady.Load(), nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// status response should include the localKEK unhealthy status
|
|
||||||
if got, err = localKEKService.Status(ctx); err != nil {
|
|
||||||
t.Fatalf("Status() error = %v, want nil", err)
|
|
||||||
}
|
|
||||||
validateStatusResponse(t, got, "v2alpha1", "remote KMS is disabled; localKEKService is not ready", "test-key-id")
|
|
||||||
|
|
||||||
// remote KMS is functional again, local KEK is rotated
|
|
||||||
remoteKMS.SetDisabledStatus(false)
|
|
||||||
fakeClock.Step(1 * time.Second)
|
|
||||||
waitUntilReady(t, localKEKService)
|
|
||||||
if got, err = localKEKService.Status(ctx); err != nil {
|
|
||||||
t.Fatalf("Status() error = %v, want nil", err)
|
|
||||||
}
|
|
||||||
validateStatusResponse(t, got, "v2alpha1", "ok", "test-key-id-2")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRotationKeyUsage(t *testing.T) {
|
|
||||||
ctx := testContext(t)
|
|
||||||
|
|
||||||
var record sync.Map
|
|
||||||
|
|
||||||
fakeClock := testingclock.NewFakeClock(time.Now())
|
|
||||||
remoteKMS := &testRemoteService{keyID: "test-key-id"}
|
|
||||||
localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Minute, 100*time.Millisecond, fakeClock)
|
|
||||||
waitUntilReady(t, localKEKService)
|
|
||||||
lk := localKEKService.localKEKTracker.Load()
|
|
||||||
encLocalKEK := lk.encKEK
|
|
||||||
|
|
||||||
// check only single call for Encrypt to remote KMS
|
|
||||||
if remoteKMS.EncryptCallCount() != 1 {
|
|
||||||
t.Fatalf("Encrypt() remoteKMS.EncryptCallCount() = %v, want %v", remoteKMS.EncryptCallCount(), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 6; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte(rand.String(32)))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Encrypt() error = %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) {
|
|
||||||
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
record.Store(resp, nil)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
if t.Failed() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fakeClock.Step(30 * time.Second)
|
|
||||||
rotated := false
|
|
||||||
// wait for the local KEK to be rotated
|
|
||||||
wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
|
|
||||||
// local KEK must have been rotated after 5 usages
|
|
||||||
lk = localKEKService.localKEKTracker.Load()
|
|
||||||
rotated = !bytes.Equal(lk.encKEK, encLocalKEK)
|
|
||||||
return rotated, nil
|
|
||||||
})
|
|
||||||
if !rotated {
|
|
||||||
t.Fatalf("local KEK must have been rotated")
|
|
||||||
}
|
|
||||||
if remoteKMS.EncryptCallCount() != 2 {
|
|
||||||
t.Fatalf("Encrypt() remoteKMS.EncryptCallCount() = %v, want %v", remoteKMS.EncryptCallCount(), 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// new local KEK must be used for encryption
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte(rand.String(32)))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Encrypt() error = %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) {
|
|
||||||
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
record.Store(resp, nil)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
if t.Failed() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check we can decrypt data encrypted with the old and new local KEKs
|
|
||||||
record.Range(func(key, _ any) bool {
|
|
||||||
k := key.(*service.EncryptResponse)
|
|
||||||
decryptRequest := &service.DecryptRequest{
|
|
||||||
Ciphertext: k.Ciphertext,
|
|
||||||
Annotations: k.Annotations,
|
|
||||||
KeyID: k.KeyID,
|
|
||||||
}
|
|
||||||
if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
|
|
||||||
t.Fatalf("Decrypt() error = %v", err)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
// Out of the 11 calls to Decrypt:
|
|
||||||
// - 5 should be using the current local KEK
|
|
||||||
// - 1 out of the 6 should generate a decrypt call to the remote KMS as the local KEK not in cache
|
|
||||||
// - 5 out of the 6 should use the cached local KEK after 1st decrypt call to the remote KMS
|
|
||||||
assertCallCount(t, remoteKMS, localKEKService)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRotationKeyExpiry(t *testing.T) {
|
|
||||||
ctx := testContext(t)
|
|
||||||
|
|
||||||
var record sync.Map
|
|
||||||
|
|
||||||
fakeClock := testingclock.NewFakeClock(time.Now())
|
|
||||||
remoteKMS := &testRemoteService{keyID: "test-key-id"}
|
|
||||||
localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 5*time.Second, 100*time.Millisecond, fakeClock)
|
|
||||||
waitUntilReady(t, localKEKService)
|
|
||||||
lk := localKEKService.localKEKTracker.Load()
|
|
||||||
encLocalKEK := lk.encKEK
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Encrypt() error = %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) {
|
|
||||||
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
record.Store(resp, nil)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
if t.Failed() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check local KEK has only been used 3 times and still under the suggested usage
|
|
||||||
if lk.usage.Load() != 3 {
|
|
||||||
t.Fatalf("local KEK usage = %v, want %v", lk.usage.Load(), 3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// advance the clock to trigger key expiry
|
|
||||||
fakeClock.Step(6 * time.Second)
|
|
||||||
|
|
||||||
rotated := false
|
|
||||||
// wait for the local KEK to be rotated due to key expiry
|
|
||||||
wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
|
|
||||||
// local KEK must have been rotated after the key max age
|
|
||||||
t.Log("waiting for local KEK to be rotated")
|
|
||||||
lk = localKEKService.localKEKTracker.Load()
|
|
||||||
rotated = !bytes.Equal(lk.encKEK, encLocalKEK)
|
|
||||||
return rotated, nil
|
|
||||||
})
|
|
||||||
if !rotated {
|
|
||||||
t.Fatalf("local KEK must have been rotated")
|
|
||||||
}
|
|
||||||
|
|
||||||
// new local KEK must be used for encryption
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Encrypt() error = %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) {
|
|
||||||
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
record.Store(resp, nil)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
if t.Failed() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check we can decrypt data encrypted with the old and new local KEKs
|
|
||||||
record.Range(func(key, _ any) bool {
|
|
||||||
k := key.(*service.EncryptResponse)
|
|
||||||
decryptRequest := &service.DecryptRequest{
|
|
||||||
Ciphertext: k.Ciphertext,
|
|
||||||
Annotations: k.Annotations,
|
|
||||||
KeyID: k.KeyID,
|
|
||||||
}
|
|
||||||
if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
|
|
||||||
t.Fatalf("Decrypt() error = %v", err)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
// Out of the 8 calls to Decrypt:
|
|
||||||
// - 5 should be using the current local KEK
|
|
||||||
// - 1 out of the 3 should generate a decrypt call to the remote KMS as the local KEK not in cache
|
|
||||||
// - 2 out of the 3 should use the cached local KEK after 1st decrypt call to the remote KMS
|
|
||||||
assertCallCount(t, remoteKMS, localKEKService)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRotationRemoteKeyIDChanged(t *testing.T) {
|
|
||||||
ctx := testContext(t)
|
|
||||||
|
|
||||||
var record sync.Map
|
|
||||||
|
|
||||||
fakeClock := testingclock.NewFakeClock(time.Now())
|
|
||||||
remoteKMS := &testRemoteService{keyID: "test-key-id"}
|
|
||||||
localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Minute, 100*time.Millisecond, fakeClock)
|
|
||||||
waitUntilReady(t, localKEKService)
|
|
||||||
lk := localKEKService.localKEKTracker.Load()
|
|
||||||
encLocalKEK := lk.encKEK
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Encrypt() error = %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) {
|
|
||||||
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
record.Store(resp, nil)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
if t.Failed() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check local KEK has only been used 3 times and still under the suggested usage
|
|
||||||
if lk.usage.Load() != 3 {
|
|
||||||
t.Fatalf("local KEK usage = %v, want %v", lk.usage.Load(), 3)
|
|
||||||
}
|
|
||||||
|
|
||||||
fakeClock.Step(30 * time.Second)
|
|
||||||
// change the remote key ID
|
|
||||||
remoteKMS.SetKeyID("test-key-id-2")
|
|
||||||
if _, err := localKEKService.Status(ctx); err != nil {
|
|
||||||
t.Fatalf("Status() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rotated := false
|
|
||||||
// wait for the local KEK to be rotated due to remote key ID change
|
|
||||||
wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
|
|
||||||
lk = localKEKService.localKEKTracker.Load()
|
|
||||||
rotated = !bytes.Equal(lk.encKEK, encLocalKEK)
|
|
||||||
return rotated, nil
|
|
||||||
})
|
|
||||||
if !rotated {
|
|
||||||
t.Fatalf("local KEK must have been rotated")
|
|
||||||
}
|
|
||||||
|
|
||||||
// new local KEK must be used for encryption
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Encrypt() error = %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) {
|
|
||||||
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
record.Store(resp, nil)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
if t.Failed() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check we can decrypt data encrypted with the old and new local KEKs
|
|
||||||
record.Range(func(key, _ any) bool {
|
|
||||||
k := key.(*service.EncryptResponse)
|
|
||||||
decryptRequest := &service.DecryptRequest{
|
|
||||||
Ciphertext: k.Ciphertext,
|
|
||||||
Annotations: k.Annotations,
|
|
||||||
KeyID: k.KeyID,
|
|
||||||
}
|
|
||||||
if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
|
|
||||||
t.Fatalf("Decrypt() error = %v", err)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
// Out of the 8 calls to Decrypt:
|
|
||||||
// - 5 should be using the current local KEK
|
|
||||||
// - 1 out of the 3 should generate a decrypt call to the remote KMS as the local KEK not in cache
|
|
||||||
// - 2 out of the 3 should use the cached local KEK after 1st decrypt call to the remote KMS
|
|
||||||
assertCallCount(t, remoteKMS, localKEKService)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testContext(t *testing.T) context.Context {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
t.Cleanup(cancel)
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitUntilReady(t *testing.T, s *LocalKEKService) {
|
|
||||||
t.Helper()
|
|
||||||
wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
|
|
||||||
return s.isReady.Load(), nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateEncryptResponse(t *testing.T, got *service.EncryptResponse, wantKeyID string, localKEKService *LocalKEKService) {
|
|
||||||
t.Helper()
|
|
||||||
if len(got.Annotations) != 2 {
|
|
||||||
t.Fatalf("Encrypt() annotations = %v, want 2 annotations", got.Annotations)
|
|
||||||
}
|
|
||||||
if _, ok := got.Annotations[referenceKEKAnnotationKey]; !ok {
|
|
||||||
t.Fatalf("Encrypt() annotations = %v, want %v", got.Annotations, referenceKEKAnnotationKey)
|
|
||||||
}
|
|
||||||
if got.KeyID != wantKeyID {
|
|
||||||
t.Fatalf("Encrypt() keyID = %v, want %v", got.KeyID, wantKeyID)
|
|
||||||
}
|
|
||||||
if localKEKService.localKEKTracker.Load() == nil {
|
|
||||||
t.Fatalf("Encrypt() localKEKTracker = %v, want non-nil localKEK", localKEKService.localKEKTracker.Load())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateStatusResponse(t *testing.T, got *service.StatusResponse, wantVersion, wantHealthz, wantKeyID string) {
|
|
||||||
t.Helper()
|
|
||||||
if got.Version != wantVersion {
|
|
||||||
t.Fatalf("Status() version = %v, want %v", got.Version, wantVersion)
|
|
||||||
}
|
|
||||||
if !strings.EqualFold(got.Healthz, wantHealthz) {
|
|
||||||
t.Fatalf("Status() healthz = %v, want %v", got.Healthz, wantHealthz)
|
|
||||||
}
|
|
||||||
if got.KeyID != wantKeyID {
|
|
||||||
t.Fatalf("Status() keyID = %v, want %v", got.KeyID, wantKeyID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertCallCount(t *testing.T, remoteKMS *testRemoteService, localKEKService *LocalKEKService) {
|
|
||||||
t.Helper()
|
|
||||||
if remoteKMS.DecryptCallCount() != 1 {
|
|
||||||
t.Fatalf("Decrypt() remoteKMS.DecryptCallCount() = %v, want %v", remoteKMS.DecryptCallCount(), 1)
|
|
||||||
}
|
|
||||||
if localKEKService.transformers.Len() != 1 {
|
|
||||||
t.Fatalf("Decrypt() localKEKService.transformers.Len() = %v, want %v", localKEKService.transformers.Len(), 1)
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user