diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/BUILD b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/BUILD index 1f08a3a050f..771634d747a 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/BUILD +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/BUILD @@ -18,6 +18,7 @@ go_test( "//staging/src/k8s.io/apimachinery/pkg/util/uuid:go_default_library", "//staging/src/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library", "//staging/src/k8s.io/apiserver/pkg/authentication/user:go_default_library", + "//vendor/github.com/google/go-cmp/cmp:go_default_library", "//vendor/github.com/google/uuid:go_default_library", ], ) @@ -32,9 +33,12 @@ go_library( importmap = "k8s.io/kubernetes/vendor/k8s.io/apiserver/pkg/authentication/token/cache", importpath = "k8s.io/apiserver/pkg/authentication/token/cache", deps = [ + "//staging/src/k8s.io/apimachinery/pkg/api/errors:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/cache:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/clock:go_default_library", "//staging/src/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library", + "//vendor/golang.org/x/sync/singleflight:go_default_library", + "//vendor/k8s.io/klog:go_default_library", ], ) diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go index ef0a8c87215..66a9fc577f5 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go @@ -22,16 +22,26 @@ import ( "crypto/rand" "crypto/sha256" "encoding/binary" + "errors" "hash" "io" + "runtime" "sync" "time" "unsafe" + "golang.org/x/sync/singleflight" + + apierrors "k8s.io/apimachinery/pkg/api/errors" utilclock "k8s.io/apimachinery/pkg/util/clock" "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/klog" ) +var errAuthnCrash = apierrors.NewInternalError(errors.New("authentication failed unexpectedly")) + +const sharedLookupTimeout = 30 * time.Second + // cacheRecord holds the three return values of the authenticator.Token AuthenticateToken method type cacheRecord struct { resp *authenticator.Response @@ -47,6 +57,7 @@ type cachedTokenAuthenticator struct { failureTTL time.Duration cache cache + group singleflight.Group // hashPool is a per authenticator pool of hash.Hash (to avoid allocations from building the Hash) // HMAC with SHA-256 and a random key is used to prevent precomputation and length extension attacks @@ -98,26 +109,71 @@ func newWithClock(authenticator authenticator.Token, cacheErrs bool, successTTL, // AuthenticateToken implements authenticator.Token func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) { - auds, _ := authenticator.AudiencesFrom(ctx) + auds, audsOk := authenticator.AudiencesFrom(ctx) key := keyFunc(a.hashPool, auds, token) if record, ok := a.cache.get(key); ok { return record.resp, record.ok, record.err } - resp, ok, err := a.authenticator.AuthenticateToken(ctx, token) - if !a.cacheErrs && err != nil { - return resp, ok, err + type lookup struct { + resp *authenticator.Response + ok bool } - switch { - case ok && a.successTTL > 0: - a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL) - case !ok && a.failureTTL > 0: - a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL) - } + c := a.group.DoChan(key, func() (val interface{}, err error) { + // We're leaving the request handling stack so we need to handle crashes + // ourselves. Log a stack trace and return a 500 if something panics. + defer func() { + if r := recover(); r != nil { + err = errAuthnCrash + // Same as stdlib http server code. Manually allocate stack + // trace buffer size to prevent excessively large logs + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + klog.Errorf("%v\n%s", r, buf) + } + }() - return resp, ok, err + // Check again for a cached record. We may have raced with a fetch. + if record, ok := a.cache.get(key); ok { + return lookup{record.resp, record.ok}, record.err + } + + // Detach the context because the lookup may be shared by multiple callers, + // however propagate the audience. + ctx, cancel := context.WithTimeout(context.Background(), sharedLookupTimeout) + defer cancel() + + if audsOk { + ctx = authenticator.WithAudiences(ctx, auds) + } + + resp, ok, err := a.authenticator.AuthenticateToken(ctx, token) + if !a.cacheErrs && err != nil { + return nil, err + } + + switch { + case ok && a.successTTL > 0: + a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL) + case !ok && a.failureTTL > 0: + a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL) + } + return lookup{resp, ok}, err + }) + + select { + case result := <-c: + if result.Err != nil { + return nil, false, result.Err + } + lookup := result.Val.(lookup) + return lookup.resp, lookup.ok, nil + case <-ctx.Done(): + return nil, false, ctx.Err() + } } // keyFunc generates a string key by hashing the inputs. diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go index c6fb207e3a7..7252a0a24d4 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go @@ -30,6 +30,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" utilclock "k8s.io/apimachinery/pkg/util/clock" "k8s.io/apimachinery/pkg/util/uuid" "k8s.io/apiserver/pkg/authentication/authenticator" @@ -173,6 +174,106 @@ func BenchmarkKeyFunc(b *testing.B) { }) } +func TestSharedLookup(t *testing.T) { + var chewie = &authenticator.Response{User: &user.DefaultInfo{Name: "chewbacca"}} + + t.Run("actually shared", func(t *testing.T) { + var lookups uint32 + c := make(chan struct{}) + a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) { + <-c + atomic.AddUint32(&lookups, 1) + return chewie, true, nil + }), true, time.Minute, 0) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + a.AuthenticateToken(context.Background(), "") + }() + } + + // no good way to make sure that all the callers are queued so we sleep. + time.Sleep(1 * time.Second) + close(c) + wg.Wait() + + if lookups > 3 { + t.Fatalf("unexpected number of lookups: got=%d, wanted less than 3", lookups) + } + }) + + t.Run("first caller bails, second caller gets result", func(t *testing.T) { + c := make(chan struct{}) + a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) { + <-c + return chewie, true, nil + }), true, time.Minute, 0) + + var wg sync.WaitGroup + wg.Add(2) + + ctx1, cancel1 := context.WithCancel(context.Background()) + go func() { + defer wg.Done() + a.AuthenticateToken(ctx1, "") + }() + + ctx2 := context.Background() + + var ( + resp *authenticator.Response + ok bool + err error + ) + go func() { + defer wg.Done() + resp, ok, err = a.AuthenticateToken(ctx2, "") + }() + + time.Sleep(1 * time.Second) + cancel1() + close(c) + wg.Wait() + + if want := chewie; !cmp.Equal(resp, want) { + t.Errorf("Unexpected diff: %v", cmp.Diff(resp, want)) + } + if !ok { + t.Errorf("Expected ok response") + } + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + + t.Run("lookup panics", func(t *testing.T) { + a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) { + panic("uh oh") + }), true, time.Minute, 0) + + _, _, err := a.AuthenticateToken(context.Background(), "") + if err != errAuthnCrash { + t.Errorf("expected error: %v", err) + } + }) + + t.Run("audiences are forwarded", func(t *testing.T) { + ctx := authenticator.WithAudiences(context.Background(), []string{"a"}) + a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) { + auds, _ := authenticator.AudiencesFrom(ctx) + if got, want := auds, []string{"a"}; cmp.Equal(got, want) { + t.Fatalf("unexpeced audiences: %v", cmp.Diff(got, want)) + } + return nil, false, nil + }), true, time.Minute, 0) + + a.AuthenticateToken(ctx, "") + }) +} + func BenchmarkCachedTokenAuthenticator(b *testing.B) { tokenCount := []int{100, 500, 2500, 12500, 62500} threadCount := []int{1, 16, 256}