mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-30 15:05:27 +00:00
Move broadcast of key updates into sync
This commit is contained in:
parent
33c64b380a
commit
dc41c91a07
@ -25,10 +25,10 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
"k8s.io/klog/v2"
|
||||
"k8s.io/kubernetes/pkg/serviceaccount"
|
||||
|
||||
externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1"
|
||||
"k8s.io/klog/v2"
|
||||
"k8s.io/kubernetes/pkg/serviceaccount"
|
||||
externaljwtmetrics "k8s.io/kubernetes/pkg/serviceaccount/externaljwt/metrics"
|
||||
)
|
||||
|
||||
@ -56,7 +56,7 @@ func newKeyCache(client externaljwtv1alpha1.ExternalJWTSignerClient) *keyCache {
|
||||
// InitialFill can be used to perform an initial fetch for keys get the
|
||||
// refresh interval as recommended by external signer.
|
||||
func (p *keyCache) initialFill(ctx context.Context) error {
|
||||
if _, err := p.syncKeys(ctx); err != nil {
|
||||
if err := p.syncKeys(ctx); err != nil {
|
||||
return fmt.Errorf("while performing initial cache fill: %w", err)
|
||||
}
|
||||
return nil
|
||||
@ -66,7 +66,6 @@ func (p *keyCache) scheduleSync(ctx context.Context, keySyncTimeout time.Duratio
|
||||
timer := time.NewTimer(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now()))
|
||||
defer timer.Stop()
|
||||
|
||||
var lastDataTimestamp time.Time
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@ -76,16 +75,11 @@ func (p *keyCache) scheduleSync(ctx context.Context, keySyncTimeout time.Duratio
|
||||
}
|
||||
|
||||
timedCtx, cancel := context.WithTimeout(ctx, keySyncTimeout)
|
||||
dataTimestamp, err := p.syncKeys(timedCtx)
|
||||
if err != nil {
|
||||
if err := p.syncKeys(timedCtx); err != nil {
|
||||
klog.Errorf("when syncing supported public keys(Stale set of keys will be supported): %v", err)
|
||||
timer.Reset(fallbackRefreshDuration)
|
||||
} else {
|
||||
timer.Reset(p.verificationKeys.Load().NextRefreshHint.Sub(time.Now()))
|
||||
if lastDataTimestamp.IsZero() || !dataTimestamp.Equal(lastDataTimestamp) {
|
||||
lastDataTimestamp = dataTimestamp
|
||||
p.broadcastUpdate()
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
@ -115,7 +109,7 @@ func (p *keyCache) GetPublicKeys(ctx context.Context, keyID string) []serviceacc
|
||||
}
|
||||
|
||||
// If we didn't find it, trigger a sync.
|
||||
if _, err := p.syncKeys(ctx); err != nil {
|
||||
if err := p.syncKeys(ctx); err != nil {
|
||||
klog.ErrorS(err, "Error while syncing keys")
|
||||
return []serviceaccount.PublicKey{}
|
||||
}
|
||||
@ -152,8 +146,9 @@ func (p *keyCache) findKeyForKeyID(keyID string) ([]serviceaccount.PublicKey, bo
|
||||
|
||||
// sync supported external keys.
|
||||
// completely re-writes the set of supported keys.
|
||||
func (p *keyCache) syncKeys(ctx context.Context) (time.Time, error) {
|
||||
val, err, _ := p.syncGroup.Do("", func() (any, error) {
|
||||
func (p *keyCache) syncKeys(ctx context.Context) error {
|
||||
_, err, _ := p.syncGroup.Do("", func() (any, error) {
|
||||
oldPublicKeys := p.verificationKeys.Load()
|
||||
newPublicKeys, err := p.getTokenVerificationKeys(ctx)
|
||||
externaljwtmetrics.RecordFetchKeysAttempt(err)
|
||||
if err != nil {
|
||||
@ -161,18 +156,38 @@ func (p *keyCache) syncKeys(ctx context.Context) (time.Time, error) {
|
||||
}
|
||||
|
||||
p.verificationKeys.Store(newPublicKeys)
|
||||
|
||||
externaljwtmetrics.RecordKeyDataTimeStamp(newPublicKeys.DataTimestamp.Unix())
|
||||
|
||||
return newPublicKeys, nil
|
||||
if keysChanged(oldPublicKeys, newPublicKeys) {
|
||||
p.broadcastUpdate()
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
return err
|
||||
}
|
||||
|
||||
// keysChanged returns true if the data timestamp, key count, order of key ids or excludeFromOIDCDiscovery indicators
|
||||
func keysChanged(oldPublicKeys, newPublicKeys *VerificationKeys) bool {
|
||||
// If the timestamp changed, we changed
|
||||
if !oldPublicKeys.DataTimestamp.Equal(newPublicKeys.DataTimestamp) {
|
||||
return true
|
||||
}
|
||||
|
||||
vk := val.(*VerificationKeys)
|
||||
|
||||
return vk.DataTimestamp, nil
|
||||
// Avoid deepequal checks on key content itself.
|
||||
// If the number of keys changed, we changed
|
||||
if len(oldPublicKeys.Keys) != len(newPublicKeys.Keys) {
|
||||
return true
|
||||
}
|
||||
// If the order, key id, or oidc discovery flag changed, we changed.
|
||||
for i := range oldPublicKeys.Keys {
|
||||
if oldPublicKeys.Keys[i].KeyID != newPublicKeys.Keys[i].KeyID {
|
||||
return true
|
||||
}
|
||||
if oldPublicKeys.Keys[i].ExcludeFromOIDCDiscovery != newPublicKeys.Keys[i].ExcludeFromOIDCDiscovery {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *keyCache) broadcastUpdate() {
|
||||
@ -180,7 +195,8 @@ func (p *keyCache) broadcastUpdate() {
|
||||
defer p.listenersLock.Unlock()
|
||||
|
||||
for _, l := range p.listeners {
|
||||
l.Enqueue()
|
||||
// don't block on a slow listener
|
||||
go l.Enqueue()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -31,6 +32,7 @@ import (
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/wait"
|
||||
externaljwtv1alpha1 "k8s.io/externaljwt/apis/v1alpha1"
|
||||
"k8s.io/kubernetes/pkg/serviceaccount"
|
||||
)
|
||||
@ -167,7 +169,7 @@ func TestExternalPublicKeyGetter(t *testing.T) {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
sockname := fmt.Sprintf("@test-external-public-key-getter-%d.sock", i)
|
||||
sockname := fmt.Sprintf("@test-external-public-key-getter-%d-%d.sock", time.Now().Nanosecond(), i)
|
||||
t.Cleanup(func() { _ = os.Remove(sockname) })
|
||||
|
||||
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
|
||||
@ -238,7 +240,7 @@ func TestExternalPublicKeyGetter(t *testing.T) {
|
||||
func TestInitialFill(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
sockname := "@test-initial-fill.sock"
|
||||
sockname := fmt.Sprintf("@test-initial-fill-%d.sock", time.Now().Nanosecond())
|
||||
t.Cleanup(func() { _ = os.Remove(sockname) })
|
||||
|
||||
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
|
||||
@ -304,7 +306,7 @@ func TestInitialFill(t *testing.T) {
|
||||
func TestReflectChanges(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
sockname := "@test-reflect-changes.sock"
|
||||
sockname := fmt.Sprintf("@test-reflect-changes-%d.sock", time.Now().Nanosecond())
|
||||
t.Cleanup(func() { _ = os.Remove(sockname) })
|
||||
|
||||
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
|
||||
@ -357,18 +359,25 @@ func TestReflectChanges(t *testing.T) {
|
||||
|
||||
plugin := newPlugin("iss", clientConn, true)
|
||||
|
||||
dummyListener := &dummyListener{}
|
||||
plugin.keyCache.AddListener(dummyListener)
|
||||
|
||||
dummyListener.waitForCount(t, 0)
|
||||
if err := plugin.keyCache.initialFill(ctx); err != nil {
|
||||
t.Fatalf("Error during InitialFill: %v", err)
|
||||
}
|
||||
dummyListener.waitForCount(t, 1)
|
||||
|
||||
gotPubKeysT1 := plugin.keyCache.GetPublicKeys(ctx, "")
|
||||
if diff := cmp.Diff(gotPubKeysT1, wantPubKeysT1, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" {
|
||||
t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff)
|
||||
}
|
||||
|
||||
if _, err := plugin.keyCache.syncKeys(ctx); err != nil {
|
||||
dummyListener.waitForCount(t, 1)
|
||||
if err := plugin.keyCache.syncKeys(ctx); err != nil {
|
||||
t.Fatalf("Error while calling syncKeys: %v", err)
|
||||
}
|
||||
dummyListener.waitForCount(t, 1)
|
||||
|
||||
supportedKeysT2 := map[string]supportedKeyT{
|
||||
"key-1": {
|
||||
@ -396,12 +405,108 @@ func TestReflectChanges(t *testing.T) {
|
||||
backend.supportedKeys = supportedKeysT2
|
||||
backend.keyLock.Unlock()
|
||||
|
||||
if _, err := plugin.keyCache.syncKeys(ctx); err != nil {
|
||||
dummyListener.waitForCount(t, 1)
|
||||
if err := plugin.keyCache.syncKeys(ctx); err != nil {
|
||||
t.Fatalf("Error while calling syncKeys: %v", err)
|
||||
}
|
||||
dummyListener.waitForCount(t, 2)
|
||||
|
||||
gotPubKeysT2 := plugin.keyCache.GetPublicKeys(ctx, "")
|
||||
if diff := cmp.Diff(gotPubKeysT2, wantPubKeysT2, cmpopts.SortSlices(sortPublicKeySlice)); diff != "" {
|
||||
t.Fatalf("Bad public keys; diff (-got +want)\n%s", diff)
|
||||
}
|
||||
dummyListener.waitForCount(t, 2)
|
||||
}
|
||||
|
||||
type dummyListener struct {
|
||||
count atomic.Int64
|
||||
}
|
||||
|
||||
func (d *dummyListener) waitForCount(t *testing.T, expect int) {
|
||||
t.Helper()
|
||||
err := wait.PollUntilContextTimeout(context.Background(), time.Millisecond, 10*time.Second, true, func(_ context.Context) (bool, error) {
|
||||
actual := int(d.count.Load())
|
||||
switch {
|
||||
case actual > expect:
|
||||
return false, fmt.Errorf("expected %d broadcasts, got %d broadcasts", expect, actual)
|
||||
case actual == expect:
|
||||
return true, nil
|
||||
default:
|
||||
t.Logf("expected %d broadcasts, got %d broadcasts, waiting...", expect, actual)
|
||||
return false, nil
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dummyListener) Enqueue() {
|
||||
d.count.Add(1)
|
||||
}
|
||||
|
||||
func TestKeysChanged(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
oldKeys VerificationKeys
|
||||
newKeys VerificationKeys
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
oldKeys: VerificationKeys{},
|
||||
newKeys: VerificationKeys{},
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "identical",
|
||||
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
|
||||
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "changed datatimestamp",
|
||||
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
|
||||
newKeys: VerificationKeys{DataTimestamp: time.Unix(1001, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "reordered keyid",
|
||||
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
|
||||
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "b"}, {KeyID: "a"}}},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "changed keyid",
|
||||
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
|
||||
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "b"}}},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "added key",
|
||||
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
|
||||
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "removed key",
|
||||
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}, {KeyID: "b"}}},
|
||||
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a"}}},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "changed oidc",
|
||||
oldKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a", ExcludeFromOIDCDiscovery: false}}},
|
||||
newKeys: VerificationKeys{DataTimestamp: time.Unix(1000, 0), Keys: []serviceaccount.PublicKey{{KeyID: "a", ExcludeFromOIDCDiscovery: true}}},
|
||||
expect: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := keysChanged(&tc.oldKeys, &tc.newKeys)
|
||||
if result != tc.expect {
|
||||
t.Errorf("got %v, expected %v", result, tc.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -258,7 +258,7 @@ func TestExternalTokenGenerator(t *testing.T) {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
sockname := fmt.Sprintf("@test-external-token-generator-%d.sock", i)
|
||||
sockname := fmt.Sprintf("@test-external-token-generator-%d-%d.sock", time.Now().Nanosecond(), i)
|
||||
t.Cleanup(func() { _ = os.Remove(sockname) })
|
||||
|
||||
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
|
||||
|
@ -59,7 +59,7 @@ func TestExternalJWTSigningAndAuth(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
// create and start mock signer.
|
||||
socketPath := "@mock-external-jwt-signer.sock"
|
||||
socketPath := fmt.Sprintf("@mock-external-jwt-signer-%d.sock", time.Now().Nanosecond())
|
||||
t.Cleanup(func() { _ = os.Remove(socketPath) })
|
||||
mockSigner := v1alpha1testing.NewMockSigner(t, socketPath)
|
||||
defer mockSigner.CleanUp()
|
||||
@ -227,7 +227,7 @@ func TestExternalJWTSigningAndAuth(t *testing.T) {
|
||||
}
|
||||
|
||||
if !tokenReviewResult.Status.Authenticated && tc.shouldPassAuth {
|
||||
t.Fatal("Expected Authentication to succeed")
|
||||
t.Fatalf("Expected Authentication to succeed, got %v", tokenReviewResult.Status.Error)
|
||||
} else if tokenReviewResult.Status.Authenticated && !tc.shouldPassAuth {
|
||||
t.Fatal("Expected Authentication to fail")
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user