From dc41c91a073e9224779b8fe1a293ccaf782a5bc9 Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Thu, 7 Nov 2024 12:15:52 -0500 Subject: [PATCH] Move broadcast of key updates into sync --- .../externaljwt/plugin/keycache.go | 60 +++++---- .../externaljwt/plugin/keycache_test.go | 115 +++++++++++++++++- .../externaljwt/plugin/plugin_test.go | 2 +- .../external_jwt_signer_test.go | 4 +- 4 files changed, 151 insertions(+), 30 deletions(-) diff --git a/pkg/serviceaccount/externaljwt/plugin/keycache.go b/pkg/serviceaccount/externaljwt/plugin/keycache.go index 546b42f50f8..521e0f881e7 100644 --- a/pkg/serviceaccount/externaljwt/plugin/keycache.go +++ b/pkg/serviceaccount/externaljwt/plugin/keycache.go @@ -25,10 +25,10 @@ import ( "time" "golang.org/x/sync/singleflight" - "k8s.io/klog/v2" - "k8s.io/kubernetes/pkg/serviceaccount" externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1" + "k8s.io/klog/v2" + "k8s.io/kubernetes/pkg/serviceaccount" externaljwtmetrics "k8s.io/kubernetes/pkg/serviceaccount/externaljwt/metrics" ) @@ -56,7 +56,7 @@ func newKeyCache(client externaljwtv1alpha1.ExternalJWTSignerClient) *keyCache { // InitialFill can be used to perform an initial fetch for keys get the // refresh interval as recommended by external signer. 
func (p *keyCache) initialFill(ctx context.Context) error { - if _, err := p.syncKeys(ctx); err != nil { + if err := p.syncKeys(ctx); err != nil { return fmt.Errorf("while performing initial cache fill: %w", err) } return nil @@ -66,7 +66,6 @@ func (p *keyCache) scheduleSync(ctx context.Context, keySyncTimeout time.Duratio timer := time.NewTimer(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now())) defer timer.Stop() - var lastDataTimestamp time.Time for { select { case <-ctx.Done(): @@ -76,16 +75,11 @@ func (p *keyCache) scheduleSync(ctx context.Context, keySyncTimeout time.Duratio } timedCtx, cancel := context.WithTimeout(ctx, keySyncTimeout) - dataTimestamp, err := p.syncKeys(timedCtx) - if err != nil { + if err := p.syncKeys(timedCtx); err != nil { klog.Errorf("when syncing supported public keys(Stale set of keys will be supported): %v", err) timer.Reset(fallbackRefreshDuration) } else { timer.Reset(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now())) - if lastDataTimestamp.IsZero() || !dataTimestamp.Equal(lastDataTimestamp) { - lastDataTimestamp = dataTimestamp - p.broadcastUpdate() - } } cancel() } @@ -115,7 +109,7 @@ func (p *keyCache) GetPublicKeys(ctx context.Context, keyID string) []serviceacc } // If we didn't find it, trigger a sync. - if _, err := p.syncKeys(ctx); err != nil { + if err := p.syncKeys(ctx); err != nil { klog.ErrorS(err, "Error while syncing keys") return []serviceaccount.PublicKey{} } @@ -152,8 +146,9 @@ func (p *keyCache) findKeyForKeyID(keyID string) ([]serviceaccount.PublicKey, bo // sync supported external keys. // completely re-writes the set of supported keys. 
-func (p *keyCache) syncKeys(ctx context.Context) (time.Time, error) { - val, err, _ := p.syncGroup.Do("", func() (any, error) { +func (p *keyCache) syncKeys(ctx context.Context) error { + _, err, _ := p.syncGroup.Do("", func() (any, error) { + oldPublicKeys := p.verificationKeys.Load() newPublicKeys, err := p.getTokenVerificationKeys(ctx) externaljwtmetrics.RecordFetchKeysAttempt(err) if err != nil { @@ -161,18 +156,38 @@ func (p *keyCache) syncKeys(ctx context.Context) (time.Time, error) { } p.verificationKeys.Store(newPublicKeys) - externaljwtmetrics.RecordKeyDataTimeStamp(newPublicKeys.DataTimestamp.Unix()) - return newPublicKeys, nil + if keysChanged(oldPublicKeys, newPublicKeys) { + p.broadcastUpdate() + } + + return nil, nil }) - if err != nil { - return time.Time{}, err + return err +} + +// keysChanged returns true if the data timestamp, key count, order of key ids or excludeFromOIDCDiscovery indicators changed. +func keysChanged(oldPublicKeys, newPublicKeys *VerificationKeys) bool { + // If the timestamp changed, we changed + if !oldPublicKeys.DataTimestamp.Equal(newPublicKeys.DataTimestamp) { + return true } - - vk := val.(*VerificationKeys) - - return vk.DataTimestamp, nil + // Avoid deepequal checks on key content itself. + // If the number of keys changed, we changed + if len(oldPublicKeys.Keys) != len(newPublicKeys.Keys) { + return true + } + // If the order, key id, or oidc discovery flag changed, we changed. 
+ for i := range oldPublicKeys.Keys { + if oldPublicKeys.Keys[i].KeyID != newPublicKeys.Keys[i].KeyID { + return true + } + if oldPublicKeys.Keys[i].ExcludeFromOIDCDiscovery != newPublicKeys.Keys[i].ExcludeFromOIDCDiscovery { + return true + } + } + return false } func (p *keyCache) broadcastUpdate() { @@ -180,7 +195,8 @@ func (p *keyCache) broadcastUpdate() { defer p.listenersLock.Unlock() for _, l := range p.listeners { - l.Enqueue() + // don't block on a slow listener + go l.Enqueue() } } diff --git a/pkg/serviceaccount/externaljwt/plugin/keycache_test.go b/pkg/serviceaccount/externaljwt/plugin/keycache_test.go index dddb4d3d44e..710e0741159 100644 --- a/pkg/serviceaccount/externaljwt/plugin/keycache_test.go +++ b/pkg/serviceaccount/externaljwt/plugin/keycache_test.go @@ -22,6 +22,7 @@ import ( "net" "os" "strings" + "sync/atomic" "testing" "time" @@ -31,6 +32,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/types/known/timestamppb" + "k8s.io/apimachinery/pkg/util/wait" externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1" "k8s.io/kubernetes/pkg/serviceaccount" ) @@ -167,7 +169,7 @@ func TestExternalPublicKeyGetter(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { ctx := context.Background() - sockname := fmt.Sprintf("@test-external-public-key-getter-%d.sock", i) + sockname := fmt.Sprintf("@test-external-public-key-getter-%d-%d.sock", time.Now().Nanosecond(), i) t.Cleanup(func() { _ = os.Remove(sockname) }) addr := &net.UnixAddr{Name: sockname, Net: "unix"} @@ -238,7 +240,7 @@ func TestExternalPublicKeyGetter(t *testing.T) { func TestInitialFill(t *testing.T) { ctx := context.Background() - sockname := "@test-initial-fill.sock" + sockname := fmt.Sprintf("@test-initial-fill-%d.sock", time.Now().Nanosecond()) t.Cleanup(func() { _ = os.Remove(sockname) }) addr := &net.UnixAddr{Name: sockname, Net: "unix"} @@ -304,7 +306,7 @@ func TestInitialFill(t *testing.T) { func TestReflectChanges(t *testing.T) { ctx := 
context.Background() - sockname := "@test-reflect-changes.sock" + sockname := fmt.Sprintf("@test-reflect-changes-%d.sock", time.Now().Nanosecond()) t.Cleanup(func() { _ = os.Remove(sockname) }) addr := &net.UnixAddr{Name: sockname, Net: "unix"} @@ -357,18 +359,25 @@ func TestReflectChanges(t *testing.T) { plugin := newPlugin("iss", clientConn, true) + dummyListener := &dummyListener{} + plugin.keyCache.AddListener(dummyListener) + + dummyListener.waitForCount(t, 0) if err := plugin.keyCache.initialFill(ctx); err != nil { t.Fatalf("Error during InitialFill: %v", err) } + dummyListener.waitForCount(t, 1) gotPubKeysT1 := plugin.keyCache.GetPublicKeys(ctx, "") if diff := cmp.Diff(gotPubKeysT1, wantPubKeysT1, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" { t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff) } - if _, err := plugin.keyCache.syncKeys(ctx); err != nil { + dummyListener.waitForCount(t, 1) + if err := plugin.keyCache.syncKeys(ctx); err != nil { t.Fatalf("Error while calling syncKeys: %v", err) } + dummyListener.waitForCount(t, 1) supportedKeysT2 := map[string]supportedKeyT{ "key-1": { @@ -396,12 +405,108 @@ func TestReflectChanges(t *testing.T) { backend.supportedKeys = supportedKeysT2 backend.keyLock.Unlock() - if _, err := plugin.keyCache.syncKeys(ctx); err != nil { + dummyListener.waitForCount(t, 1) + if err := plugin.keyCache.syncKeys(ctx); err != nil { t.Fatalf("Error while calling syncKeys: %v", err) } + dummyListener.waitForCount(t, 2) gotPubKeysT2 := plugin.keyCache.GetPublicKeys(ctx, "") if diff := cmp.Diff(gotPubKeysT2, wantPubKeysT2, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" { t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff) } + dummyListener.waitForCount(t, 2) +} + +type dummyListener struct { + count atomic.Int64 +} + +func (d *dummyListener) waitForCount(t *testing.T, expect int) { + t.Helper() + err := wait.PollUntilContextTimeout(context.Background(), time.Millisecond, 10*time.Second, true, func(_ context.Context) 
(bool, error) { + actual := int(d.count.Load()) + switch { + case actual > expect: + return false, fmt.Errorf("expected %d broadcasts, got %d broadcasts", expect, actual) + case actual == expect: + return true, nil + default: + t.Logf("expected %d broadcasts, got %d broadcasts, waiting...", expect, actual) + return false, nil + } + }) + if err != nil { + t.Fatal(err) + } +} + +func (d *dummyListener) Enqueue() { + d.count.Add(1) +} + +func TestKeysChanged(t *testing.T) { + testcases := []struct { + name string + oldKeys VerificationKeys + newKeys VerificationKeys + expect bool + }{ + { + name: "empty", + oldKeys: VerificationKeys{}, + newKeys: VerificationKeys{}, + expect: false, + }, + { + name: "identical", + oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}}, + newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}}, + expect: false, + }, + { + name: "changed datatimestamp", + oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}}, + newKeys: VerificationKeys{DataTimestamp: time.Unix(1001, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}}, + expect: true, + }, + { + name: "reordered keyid", + oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}}, + newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "b"}, {KeyID: "a"}}}, + expect: true, + }, + { + name: "changed keyid", + oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}}, + newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "b"}}}, + expect: true, + }, + { + name: "added key", + oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: 
[]serviceaccount.PublicKey{{KeyID: "a"}}}, + newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}}, + expect: true, + }, + { + name: "removed key", + oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}}, + newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}}, + expect: true, + }, + { + name: "changed oidc", + oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a", ExcludeFromOIDCDiscovery: false}}}, + newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a", ExcludeFromOIDCDiscovery: true}}}, + expect: true, + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + result := keysChanged(&tc.oldKeys, &tc.newKeys) + if result != tc.expect { + t.Errorf("got %v, expected %v", result, tc.expect) + } + }) + } } diff --git a/pkg/serviceaccount/externaljwt/plugin/plugin_test.go b/pkg/serviceaccount/externaljwt/plugin/plugin_test.go index 20a4f648a4b..106a42d7769 100644 --- a/pkg/serviceaccount/externaljwt/plugin/plugin_test.go +++ b/pkg/serviceaccount/externaljwt/plugin/plugin_test.go @@ -258,7 +258,7 @@ func TestExternalTokenGenerator(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { ctx := context.Background() - sockname := fmt.Sprintf("@test-external-token-generator-%d.sock", i) + sockname := fmt.Sprintf("@test-external-token-generator-%d-%d.sock", time.Now().Nanosecond(), i) t.Cleanup(func() { _ = os.Remove(sockname) }) addr := &net.UnixAddr{Name: sockname, Net: "unix"} diff --git a/test/integration/serviceaccount/external_jwt_signer_test.go b/test/integration/serviceaccount/external_jwt_signer_test.go index 8feb3065091..4650d7feb82 100644 --- a/test/integration/serviceaccount/external_jwt_signer_test.go +++ 
b/test/integration/serviceaccount/external_jwt_signer_test.go @@ -59,7 +59,7 @@ func TestExternalJWTSigningAndAuth(t *testing.T) { defer cancel() // create and start mock signer. - socketPath := "@mock-external-jwt-signer.sock" + socketPath := fmt.Sprintf("@mock-external-jwt-signer-%d.sock", time.Now().Nanosecond()) t.Cleanup(func() { _ = os.Remove(socketPath) }) mockSigner := v1alpha1testing.NewMockSigner(t, socketPath) defer mockSigner.CleanUp() @@ -227,7 +227,7 @@ func TestExternalJWTSigningAndAuth(t *testing.T) { } if !tokenReviewResult.Status.Authenticated && tc.shouldPassAuth { - t.Fatal("Expected Authentication to succeed") + t.Fatalf("Expected Authentication to succeed, got %v", tokenReviewResult.Status.Error) } else if tokenReviewResult.Status.Authenticated && !tc.shouldPassAuth { t.Fatal("Expected Authentication to fail") }