Move broadcast of key updates into sync

This commit is contained in:
Jordan Liggitt 2024-11-07 12:15:52 -05:00
parent 33c64b380a
commit dc41c91a07
No known key found for this signature in database
4 changed files with 151 additions and 30 deletions

View File

@ -25,10 +25,10 @@ import (
"time" "time"
"golang.org/x/sync/singleflight" "golang.org/x/sync/singleflight"
"k8s.io/klog/v2"
"k8s.io/kubernetes/pkg/serviceaccount"
externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1" externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1"
"k8s.io/klog/v2"
"k8s.io/kubernetes/pkg/serviceaccount"
externaljwtmetrics "k8s.io/kubernetes/pkg/serviceaccount/externaljwt/metrics" externaljwtmetrics "k8s.io/kubernetes/pkg/serviceaccount/externaljwt/metrics"
) )
@ -56,7 +56,7 @@ func newKeyCache(client externaljwtv1alpha1.ExternalJWTSignerClient) *keyCache {
// InitialFill can be used to perform an initial fetch for keys get the // InitialFill can be used to perform an initial fetch for keys get the
// refresh interval as recommended by external signer. // refresh interval as recommended by external signer.
func (p *keyCache) initialFill(ctx context.Context) error { func (p *keyCache) initialFill(ctx context.Context) error {
if _, err := p.syncKeys(ctx); err != nil { if err := p.syncKeys(ctx); err != nil {
return fmt.Errorf("while performing initial cache fill: %w", err) return fmt.Errorf("while performing initial cache fill: %w", err)
} }
return nil return nil
@ -66,7 +66,6 @@ func (p *keyCache) scheduleSync(ctx context.Context, keySyncTimeout time.Duratio
timer := time.NewTimer(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now())) timer := time.NewTimer(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now()))
defer timer.Stop() defer timer.Stop()
var lastDataTimestamp time.Time
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -76,16 +75,11 @@ func (p *keyCache) scheduleSync(ctx context.Context, keySyncTimeout time.Duratio
} }
timedCtx, cancel := context.WithTimeout(ctx, keySyncTimeout) timedCtx, cancel := context.WithTimeout(ctx, keySyncTimeout)
dataTimestamp, err := p.syncKeys(timedCtx) if err := p.syncKeys(timedCtx); err != nil {
if err != nil {
klog.Errorf("when syncing supported public keys(Stale set of keys will be supported): %v", err) klog.Errorf("when syncing supported public keys(Stale set of keys will be supported): %v", err)
timer.Reset(fallbackRefreshDuration) timer.Reset(fallbackRefreshDuration)
} else { } else {
timer.Reset(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now())) timer.Reset(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now()))
if lastDataTimestamp.IsZero() || !dataTimestamp.Equal(lastDataTimestamp) {
lastDataTimestamp = dataTimestamp
p.broadcastUpdate()
}
} }
cancel() cancel()
} }
@ -115,7 +109,7 @@ func (p *keyCache) GetPublicKeys(ctx context.Context, keyID string) []serviceacc
} }
// If we didn't find it, trigger a sync. // If we didn't find it, trigger a sync.
if _, err := p.syncKeys(ctx); err != nil { if err := p.syncKeys(ctx); err != nil {
klog.ErrorS(err, "Error while syncing keys") klog.ErrorS(err, "Error while syncing keys")
return []serviceaccount.PublicKey{} return []serviceaccount.PublicKey{}
} }
@ -152,8 +146,9 @@ func (p *keyCache) findKeyForKeyID(keyID string) ([]serviceaccount.PublicKey, bo
// sync supported external keys. // sync supported external keys.
// completely re-writes the set of supported keys. // completely re-writes the set of supported keys.
func (p *keyCache) syncKeys(ctx context.Context) (time.Time, error) { func (p *keyCache) syncKeys(ctx context.Context) error {
val, err, _ := p.syncGroup.Do("", func() (any, error) { _, err, _ := p.syncGroup.Do("", func() (any, error) {
oldPublicKeys := p.verificationKeys.Load()
newPublicKeys, err := p.getTokenVerificationKeys(ctx) newPublicKeys, err := p.getTokenVerificationKeys(ctx)
externaljwtmetrics.RecordFetchKeysAttempt(err) externaljwtmetrics.RecordFetchKeysAttempt(err)
if err != nil { if err != nil {
@ -161,18 +156,38 @@ func (p *keyCache) syncKeys(ctx context.Context) (time.Time, error) {
} }
p.verificationKeys.Store(newPublicKeys) p.verificationKeys.Store(newPublicKeys)
externaljwtmetrics.RecordKeyDataTimeStamp(newPublicKeys.DataTimestamp.Unix()) externaljwtmetrics.RecordKeyDataTimeStamp(newPublicKeys.DataTimestamp.Unix())
return newPublicKeys, nil if keysChanged(oldPublicKeys, newPublicKeys) {
p.broadcastUpdate()
}
return nil, nil
}) })
if err != nil { return err
return time.Time{}, err }
// keysChanged returns true if the data timestamp, key count, order of key ids or excludeFromOIDCDiscovery indicators
func keysChanged(oldPublicKeys, newPublicKeys *VerificationKeys) bool {
// If the timestamp changed, we changed
if !oldPublicKeys.DataTimestamp.Equal(newPublicKeys.DataTimestamp) {
return true
} }
// Avoid deepequal checks on key content itself.
vk := val.(*VerificationKeys) // If the number of keys changed, we changed
if len(oldPublicKeys.Keys) != len(newPublicKeys.Keys) {
return vk.DataTimestamp, nil return true
}
// If the order, key id, or oidc discovery flag changed, we changed.
for i := range oldPublicKeys.Keys {
if oldPublicKeys.Keys[i].KeyID != newPublicKeys.Keys[i].KeyID {
return true
}
if oldPublicKeys.Keys[i].ExcludeFromOIDCDiscovery != newPublicKeys.Keys[i].ExcludeFromOIDCDiscovery {
return true
}
}
return false
} }
func (p *keyCache) broadcastUpdate() { func (p *keyCache) broadcastUpdate() {
@ -180,7 +195,8 @@ func (p *keyCache) broadcastUpdate() {
defer p.listenersLock.Unlock() defer p.listenersLock.Unlock()
for _, l := range p.listeners { for _, l := range p.listeners {
l.Enqueue() // don't block on a slow listener
go l.Enqueue()
} }
} }

View File

@ -22,6 +22,7 @@ import (
"net" "net"
"os" "os"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -31,6 +32,7 @@ import (
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"k8s.io/apimachinery/pkg/util/wait"
externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1" externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1"
"k8s.io/kubernetes/pkg/serviceaccount" "k8s.io/kubernetes/pkg/serviceaccount"
) )
@ -167,7 +169,7 @@ func TestExternalPublicKeyGetter(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
sockname := fmt.Sprintf("@test-external-public-key-getter-%d.sock", i) sockname := fmt.Sprintf("@test-external-public-key-getter-%d-%d.sock", time.Now().Nanosecond(), i)
t.Cleanup(func() { _ = os.Remove(sockname) }) t.Cleanup(func() { _ = os.Remove(sockname) })
addr := &net.UnixAddr{Name: sockname, Net: "unix"} addr := &net.UnixAddr{Name: sockname, Net: "unix"}
@ -238,7 +240,7 @@ func TestExternalPublicKeyGetter(t *testing.T) {
func TestInitialFill(t *testing.T) { func TestInitialFill(t *testing.T) {
ctx := context.Background() ctx := context.Background()
sockname := "@test-initial-fill.sock" sockname := fmt.Sprintf("@test-initial-fill-%d.sock", time.Now().Nanosecond())
t.Cleanup(func() { _ = os.Remove(sockname) }) t.Cleanup(func() { _ = os.Remove(sockname) })
addr := &net.UnixAddr{Name: sockname, Net: "unix"} addr := &net.UnixAddr{Name: sockname, Net: "unix"}
@ -304,7 +306,7 @@ func TestInitialFill(t *testing.T) {
func TestReflectChanges(t *testing.T) { func TestReflectChanges(t *testing.T) {
ctx := context.Background() ctx := context.Background()
sockname := "@test-reflect-changes.sock" sockname := fmt.Sprintf("@test-reflect-changes-%d.sock", time.Now().Nanosecond())
t.Cleanup(func() { _ = os.Remove(sockname) }) t.Cleanup(func() { _ = os.Remove(sockname) })
addr := &net.UnixAddr{Name: sockname, Net: "unix"} addr := &net.UnixAddr{Name: sockname, Net: "unix"}
@ -357,18 +359,25 @@ func TestReflectChanges(t *testing.T) {
plugin := newPlugin("iss", clientConn, true) plugin := newPlugin("iss", clientConn, true)
dummyListener := &dummyListener{}
plugin.keyCache.AddListener(dummyListener)
dummyListener.waitForCount(t, 0)
if err := plugin.keyCache.initialFill(ctx); err != nil { if err := plugin.keyCache.initialFill(ctx); err != nil {
t.Fatalf("Error during InitialFill: %v", err) t.Fatalf("Error during InitialFill: %v", err)
} }
dummyListener.waitForCount(t, 1)
gotPubKeysT1 := plugin.keyCache.GetPublicKeys(ctx, "") gotPubKeysT1 := plugin.keyCache.GetPublicKeys(ctx, "")
if diff := cmp.Diff(gotPubKeysT1, wantPubKeysT1, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" { if diff := cmp.Diff(gotPubKeysT1, wantPubKeysT1, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" {
t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff) t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff)
} }
if _, err := plugin.keyCache.syncKeys(ctx); err != nil { dummyListener.waitForCount(t, 1)
if err := plugin.keyCache.syncKeys(ctx); err != nil {
t.Fatalf("Error while calling syncKeys: %v", err) t.Fatalf("Error while calling syncKeys: %v", err)
} }
dummyListener.waitForCount(t, 1)
supportedKeysT2 := map[string]supportedKeyT{ supportedKeysT2 := map[string]supportedKeyT{
"key-1": { "key-1": {
@ -396,12 +405,108 @@ func TestReflectChanges(t *testing.T) {
backend.supportedKeys = supportedKeysT2 backend.supportedKeys = supportedKeysT2
backend.keyLock.Unlock() backend.keyLock.Unlock()
if _, err := plugin.keyCache.syncKeys(ctx); err != nil { dummyListener.waitForCount(t, 1)
if err := plugin.keyCache.syncKeys(ctx); err != nil {
t.Fatalf("Error while calling syncKeys: %v", err) t.Fatalf("Error while calling syncKeys: %v", err)
} }
dummyListener.waitForCount(t, 2)
gotPubKeysT2 := plugin.keyCache.GetPublicKeys(ctx, "") gotPubKeysT2 := plugin.keyCache.GetPublicKeys(ctx, "")
if diff := cmp.Diff(gotPubKeysT2, wantPubKeysT2, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" { if diff := cmp.Diff(gotPubKeysT2, wantPubKeysT2, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" {
t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff) t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff)
} }
dummyListener.waitForCount(t, 2)
}
type dummyListener struct {
count atomic.Int64
}
func (d *dummyListener) waitForCount(t *testing.T, expect int) {
t.Helper()
err := wait.PollUntilContextTimeout(context.Background(), time.Millisecond, 10*time.Second, true, func(_ context.Context) (bool, error) {
actual := int(d.count.Load())
switch {
case actual > expect:
return false, fmt.Errorf("expected %d broadcasts, got %d broadcasts", expect, actual)
case actual == expect:
return true, nil
default:
t.Logf("expected %d broadcasts, got %d broadcasts, waiting...", expect, actual)
return false, nil
}
})
if err != nil {
t.Fatal(err)
}
}
func (d *dummyListener) Enqueue() {
d.count.Add(1)
}
func TestKeysChanged(t *testing.T) {
testcases := []struct {
name string
oldKeys VerificationKeys
newKeys VerificationKeys
expect bool
}{
{
name: "empty",
oldKeys: VerificationKeys{},
newKeys: VerificationKeys{},
expect: false,
},
{
name: "identical",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
expect: false,
},
{
name: "changed datatimestamp",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1001, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
expect: true,
},
{
name: "reordered keyid",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "b"}, {KeyID: "a"}}},
expect: true,
},
{
name: "changed keyid",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "b"}}},
expect: true,
},
{
name: "added key",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
expect: true,
},
{
name: "removed key",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
expect: true,
},
{
name: "changed oidc",
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a", ExcludeFromOIDCDiscovery: false}}},
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a", ExcludeFromOIDCDiscovery: true}}},
expect: true,
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
result := keysChanged(&tc.oldKeys, &tc.newKeys)
if result != tc.expect {
t.Errorf("got %v, expected %v", result, tc.expect)
}
})
}
} }

View File

@ -258,7 +258,7 @@ func TestExternalTokenGenerator(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
sockname := fmt.Sprintf("@test-external-token-generator-%d.sock", i) sockname := fmt.Sprintf("@test-external-token-generator-%d-%d.sock", time.Now().Nanosecond(), i)
t.Cleanup(func() { _ = os.Remove(sockname) }) t.Cleanup(func() { _ = os.Remove(sockname) })
addr := &net.UnixAddr{Name: sockname, Net: "unix"} addr := &net.UnixAddr{Name: sockname, Net: "unix"}

View File

@ -59,7 +59,7 @@ func TestExternalJWTSigningAndAuth(t *testing.T) {
defer cancel() defer cancel()
// create and start mock signer. // create and start mock signer.
socketPath := "@mock-external-jwt-signer.sock" socketPath := fmt.Sprintf("@mock-external-jwt-signer-%d.sock", time.Now().Nanosecond())
t.Cleanup(func() { _ = os.Remove(socketPath) }) t.Cleanup(func() { _ = os.Remove(socketPath) })
mockSigner := v1alpha1testing.NewMockSigner(t, socketPath) mockSigner := v1alpha1testing.NewMockSigner(t, socketPath)
defer mockSigner.CleanUp() defer mockSigner.CleanUp()
@ -227,7 +227,7 @@ func TestExternalJWTSigningAndAuth(t *testing.T) {
} }
if !tokenReviewResult.Status.Authenticated && tc.shouldPassAuth { if !tokenReviewResult.Status.Authenticated && tc.shouldPassAuth {
t.Fatal("Expected Authentication to succeed") t.Fatalf("Expected Authentication to succeed, got %v", tokenReviewResult.Status.Error)
} else if tokenReviewResult.Status.Authenticated && !tc.shouldPassAuth { } else if tokenReviewResult.Status.Authenticated && !tc.shouldPassAuth {
t.Fatal("Expected Authentication to fail") t.Fatal("Expected Authentication to fail")
} }