Move broadcast of key updates into sync

This commit is contained in:
Jordan Liggitt 2024-11-07 12:15:52 -05:00
parent 33c64b380a
commit dc41c91a07
No known key found for this signature in database
4 changed files with 151 additions and 30 deletions

View File

@ -25,10 +25,10 @@ import (
"time"
"golang.org/x/sync/singleflight"
"k8s.io/klog/v2"
"k8s.io/kubernetes/pkg/serviceaccount"
externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1"
"k8s.io/klog/v2"
"k8s.io/kubernetes/pkg/serviceaccount"
externaljwtmetrics "k8s.io/kubernetes/pkg/serviceaccount/externaljwt/metrics"
)
@ -56,7 +56,7 @@ func newKeyCache(client externaljwtv1alpha1.ExternalJWTSignerClient) *keyCache {
// InitialFill can be used to perform an initial fetch for keys get the
// refresh interval as recommended by external signer.
func (p *keyCache) initialFill(ctx context.Context) error {
if _, err := p.syncKeys(ctx); err != nil {
if err := p.syncKeys(ctx); err != nil {
return fmt.Errorf("while performing initial cache fill: %w", err)
}
return nil
@ -66,7 +66,6 @@ func (p *keyCache) scheduleSync(ctx context.Context, keySyncTimeout time.Duratio
timer := time.NewTimer(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now()))
defer timer.Stop()
var lastDataTimestamp time.Time
for {
select {
case <-ctx.Done():
@ -76,16 +75,11 @@ func (p *keyCache) scheduleSync(ctx context.Context, keySyncTimeout time.Duratio
}
timedCtx, cancel := context.WithTimeout(ctx, keySyncTimeout)
dataTimestamp, err := p.syncKeys(timedCtx)
if err != nil {
if err := p.syncKeys(timedCtx); err != nil {
klog.Errorf("when syncing supported public keys(Stale set of keys will be supported): %v", err)
timer.Reset(fallbackRefreshDuration)
} else {
timer.Reset(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now()))
if lastDataTimestamp.IsZero() || !dataTimestamp.Equal(lastDataTimestamp) {
lastDataTimestamp = dataTimestamp
p.broadcastUpdate()
}
}
cancel()
}
@ -115,7 +109,7 @@ func (p *keyCache) GetPublicKeys(ctx context.Context, keyID string) []serviceacc
}
// If we didn't find it, trigger a sync.
if _, err := p.syncKeys(ctx); err != nil {
if err := p.syncKeys(ctx); err != nil {
klog.ErrorS(err, "Error while syncing keys")
return []serviceaccount.PublicKey{}
}
@ -152,8 +146,9 @@ func (p *keyCache) findKeyForKeyID(keyID string) ([]serviceaccount.PublicKey, bo
// sync supported external keys.
// completely re-writes the set of supported keys.
func (p *keyCache) syncKeys(ctx context.Context) (time.Time, error) {
val, err, _ := p.syncGroup.Do("", func() (any, error) {
func (p *keyCache) syncKeys(ctx context.Context) error {
_, err, _ := p.syncGroup.Do("", func() (any, error) {
oldPublicKeys := p.verificationKeys.Load()
newPublicKeys, err := p.getTokenVerificationKeys(ctx)
externaljwtmetrics.RecordFetchKeysAttempt(err)
if err != nil {
@ -161,18 +156,38 @@ func (p *keyCache) syncKeys(ctx context.Context) (time.Time, error) {
}
p.verificationKeys.Store(newPublicKeys)
externaljwtmetrics.RecordKeyDataTimeStamp(newPublicKeys.DataTimestamp.Unix())
return newPublicKeys, nil
if keysChanged(oldPublicKeys, newPublicKeys) {
p.broadcastUpdate()
}
return nil, nil
})
if err != nil {
return time.Time{}, err
return err
}
// keysChanged returns true if the data timestamp, key count, order of key ids or excludeFromOIDCDiscovery indicators
func keysChanged(oldPublicKeys, newPublicKeys *VerificationKeys) bool {
// If the timestamp changed, we changed
if !oldPublicKeys.DataTimestamp.Equal(newPublicKeys.DataTimestamp) {
return true
}
vk := val.(*VerificationKeys)
return vk.DataTimestamp, nil
// Avoid deepequal checks on key content itself.
// If the number of keys changed, we changed
if len(oldPublicKeys.Keys) != len(newPublicKeys.Keys) {
return true
}
// If the order, key id, or oidc discovery flag changed, we changed.
for i := range oldPublicKeys.Keys {
if oldPublicKeys.Keys[i].KeyID != newPublicKeys.Keys[i].KeyID {
return true
}
if oldPublicKeys.Keys[i].ExcludeFromOIDCDiscovery != newPublicKeys.Keys[i].ExcludeFromOIDCDiscovery {
return true
}
}
return false
}
func (p *keyCache) broadcastUpdate() {
@ -180,7 +195,8 @@ func (p *keyCache) broadcastUpdate() {
defer p.listenersLock.Unlock()
for _, l := range p.listeners {
l.Enqueue()
// don't block on a slow listener
go l.Enqueue()
}
}

View File

@ -22,6 +22,7 @@ import (
"net"
"os"
"strings"
"sync/atomic"
"testing"
"time"
@ -31,6 +32,7 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/types/known/timestamppb"
"k8s.io/apimachinery/pkg/util/wait"
externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1"
"k8s.io/kubernetes/pkg/serviceaccount"
)
@ -167,7 +169,7 @@ func TestExternalPublicKeyGetter(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.Background()
sockname := fmt.Sprintf("@test-external-public-key-getter-%d.sock", i)
sockname := fmt.Sprintf("@test-external-public-key-getter-%d-%d.sock", time.Now().Nanosecond(), i)
t.Cleanup(func() { _ = os.Remove(sockname) })
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
@ -238,7 +240,7 @@ func TestExternalPublicKeyGetter(t *testing.T) {
func TestInitialFill(t *testing.T) {
ctx := context.Background()
sockname := "@test-initial-fill.sock"
sockname := fmt.Sprintf("@test-initial-fill-%d.sock", time.Now().Nanosecond())
t.Cleanup(func() { _ = os.Remove(sockname) })
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
@ -304,7 +306,7 @@ func TestInitialFill(t *testing.T) {
func TestReflectChanges(t *testing.T) {
ctx := context.Background()
sockname := "@test-reflect-changes.sock"
sockname := fmt.Sprintf("@test-reflect-changes-%d.sock", time.Now().Nanosecond())
t.Cleanup(func() { _ = os.Remove(sockname) })
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
@ -357,18 +359,25 @@ func TestReflectChanges(t *testing.T) {
plugin := newPlugin("iss", clientConn, true)
dummyListener := &dummyListener{}
plugin.keyCache.AddListener(dummyListener)
dummyListener.waitForCount(t, 0)
if err := plugin.keyCache.initialFill(ctx); err != nil {
t.Fatalf("Error during InitialFill: %v", err)
}
dummyListener.waitForCount(t, 1)
gotPubKeysT1 := plugin.keyCache.GetPublicKeys(ctx, "")
if diff := cmp.Diff(gotPubKeysT1, wantPubKeysT1, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" {
t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff)
}
if _, err := plugin.keyCache.syncKeys(ctx); err != nil {
dummyListener.waitForCount(t, 1)
if err := plugin.keyCache.syncKeys(ctx); err != nil {
t.Fatalf("Error while calling syncKeys: %v", err)
}
dummyListener.waitForCount(t, 1)
supportedKeysT2 := map[string]supportedKeyT{
"key-1": {
@ -396,12 +405,108 @@ func TestReflectChanges(t *testing.T) {
backend.supportedKeys = supportedKeysT2
backend.keyLock.Unlock()
if _, err := plugin.keyCache.syncKeys(ctx); err != nil {
dummyListener.waitForCount(t, 1)
if err := plugin.keyCache.syncKeys(ctx); err != nil {
t.Fatalf("Error while calling syncKeys: %v", err)
}
dummyListener.waitForCount(t, 2)
gotPubKeysT2 := plugin.keyCache.GetPublicKeys(ctx, "")
if diff := cmp.Diff(gotPubKeysT2, wantPubKeysT2, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" {
t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff)
}
dummyListener.waitForCount(t, 2)
}
type dummyListener struct {
count atomic.Int64
}
func (d *dummyListener) waitForCount(t *testing.T, expect int) {
t.Helper()
err := wait.PollUntilContextTimeout(context.Background(), time.Millisecond, 10*time.Second, true, func(_ context.Context) (bool, error) {
actual := int(d.count.Load())
switch {
case actual > expect:
return false, fmt.Errorf("expected %d broadcasts, got %d broadcasts", expect, actual)
case actual == expect:
return true, nil
default:
t.Logf("expected %d broadcasts, got %d broadcasts, waiting...", expect, actual)
return false, nil
}
})
if err != nil {
t.Fatal(err)
}
}
func (d *dummyListener) Enqueue() {
d.count.Add(1)
}
func TestKeysChanged(t *testing.T) {
testcases := []struct {
name string
oldKeys VerificationKeys
newKeys VerificationKeys
expect bool
}{
{
name: "empty",
oldKeys: VerificationKeys{},
newKeys: VerificationKeys{},
expect: false,
},
{
name: "identical",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
expect: false,
},
{
name: "changed datatimestamp",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1001, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
expect: true,
},
{
name: "reordered keyid",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "b"}, {KeyID: "a"}}},
expect: true,
},
{
name: "changed keyid",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "b"}}},
expect: true,
},
{
name: "added key",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
expect: true,
},
{
name: "removed key",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
expect: true,
},
{
name: "changed oidc",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a", ExcludeFromOIDCDiscovery: false}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a", ExcludeFromOIDCDiscovery: true}}},
expect: true,
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
result := keysChanged(&tc.oldKeys, &tc.newKeys)
if result != tc.expect {
t.Errorf("got %v, expected %v", result, tc.expect)
}
})
}
}

View File

@ -258,7 +258,7 @@ func TestExternalTokenGenerator(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.Background()
sockname := fmt.Sprintf("@test-external-token-generator-%d.sock", i)
sockname := fmt.Sprintf("@test-external-token-generator-%d-%d.sock", time.Now().Nanosecond(), i)
t.Cleanup(func() { _ = os.Remove(sockname) })
addr := &net.UnixAddr{Name: sockname, Net: "unix"}

View File

@ -59,7 +59,7 @@ func TestExternalJWTSigningAndAuth(t *testing.T) {
defer cancel()
// create and start mock signer.
socketPath := "@mock-external-jwt-signer.sock"
socketPath := fmt.Sprintf("@mock-external-jwt-signer-%d.sock", time.Now().Nanosecond())
t.Cleanup(func() { _ = os.Remove(socketPath) })
mockSigner := v1alpha1testing.NewMockSigner(t, socketPath)
defer mockSigner.CleanUp()
@ -227,7 +227,7 @@ func TestExternalJWTSigningAndAuth(t *testing.T) {
}
if !tokenReviewResult.Status.Authenticated && tc.shouldPassAuth {
t.Fatal("Expected Authentication to succeed")
t.Fatalf("Expected Authentication to succeed, got %v", tokenReviewResult.Status.Error)
} else if tokenReviewResult.Status.Authenticated && !tc.shouldPassAuth {
t.Fatal("Expected Authentication to fail")
}