shared authenticator lookups

This commit is contained in:
Mike Danese 2019-11-11 22:15:40 -08:00
parent 009c731a88
commit 8647e75cec
3 changed files with 172 additions and 11 deletions

View File

@ -18,6 +18,7 @@ go_test(
"//staging/src/k8s.io/apimachinery/pkg/util/uuid:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/authentication/user:go_default_library",
"//vendor/github.com/google/go-cmp/cmp:go_default_library",
"//vendor/github.com/google/uuid:go_default_library",
],
)
@ -32,9 +33,12 @@ go_library(
importmap = "k8s.io/kubernetes/vendor/k8s.io/apiserver/pkg/authentication/token/cache",
importpath = "k8s.io/apiserver/pkg/authentication/token/cache",
deps = [
"//staging/src/k8s.io/apimachinery/pkg/api/errors:go_default_library",
"//staging/src/k8s.io/apimachinery/pkg/util/cache:go_default_library",
"//staging/src/k8s.io/apimachinery/pkg/util/clock:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library",
"//vendor/golang.org/x/sync/singleflight:go_default_library",
"//vendor/k8s.io/klog:go_default_library",
],
)

View File

@ -22,16 +22,26 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"errors"
"hash"
"io"
"runtime"
"sync"
"time"
"unsafe"
"golang.org/x/sync/singleflight"
apierrors "k8s.io/apimachinery/pkg/api/errors"
utilclock "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/klog"
)
var errAuthnCrash = apierrors.NewInternalError(errors.New("authentication failed unexpectedly"))
const sharedLookupTimeout = 30 * time.Second
// cacheRecord holds the three return values of the authenticator.Token AuthenticateToken method
type cacheRecord struct {
resp *authenticator.Response
@ -47,6 +57,7 @@ type cachedTokenAuthenticator struct {
failureTTL time.Duration
cache cache
group singleflight.Group
// hashPool is a per authenticator pool of hash.Hash (to avoid allocations from building the Hash)
// HMAC with SHA-256 and a random key is used to prevent precomputation and length extension attacks
@ -98,26 +109,71 @@ func newWithClock(authenticator authenticator.Token, cacheErrs bool, successTTL,
// AuthenticateToken implements authenticator.Token
func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) {
auds, _ := authenticator.AudiencesFrom(ctx)
auds, audsOk := authenticator.AudiencesFrom(ctx)
key := keyFunc(a.hashPool, auds, token)
if record, ok := a.cache.get(key); ok {
return record.resp, record.ok, record.err
}
resp, ok, err := a.authenticator.AuthenticateToken(ctx, token)
if !a.cacheErrs && err != nil {
return resp, ok, err
type lookup struct {
resp *authenticator.Response
ok bool
}
switch {
case ok && a.successTTL > 0:
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL)
case !ok && a.failureTTL > 0:
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL)
}
c := a.group.DoChan(key, func() (val interface{}, err error) {
// We're leaving the request handling stack so we need to handle crashes
// ourselves. Log a stack trace and return a 500 if something panics.
defer func() {
if r := recover(); r != nil {
err = errAuthnCrash
// Same as stdlib http server code. Manually allocate stack
// trace buffer size to prevent excessively large logs
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
klog.Errorf("%v\n%s", r, buf)
}
}()
return resp, ok, err
// Check again for a cached record. We may have raced with a fetch.
if record, ok := a.cache.get(key); ok {
return lookup{record.resp, record.ok}, record.err
}
// Detach the context because the lookup may be shared by multiple callers,
// however propagate the audience.
ctx, cancel := context.WithTimeout(context.Background(), sharedLookupTimeout)
defer cancel()
if audsOk {
ctx = authenticator.WithAudiences(ctx, auds)
}
resp, ok, err := a.authenticator.AuthenticateToken(ctx, token)
if !a.cacheErrs && err != nil {
return nil, err
}
switch {
case ok && a.successTTL > 0:
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL)
case !ok && a.failureTTL > 0:
a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL)
}
return lookup{resp, ok}, err
})
select {
case result := <-c:
if result.Err != nil {
return nil, false, result.Err
}
lookup := result.Val.(lookup)
return lookup.resp, lookup.ok, nil
case <-ctx.Done():
return nil, false, ctx.Err()
}
}
// keyFunc generates a string key by hashing the inputs.

View File

@ -30,6 +30,7 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
utilclock "k8s.io/apimachinery/pkg/util/clock"
"k8s.io/apimachinery/pkg/util/uuid"
"k8s.io/apiserver/pkg/authentication/authenticator"
@ -173,6 +174,106 @@ func BenchmarkKeyFunc(b *testing.B) {
})
}
func TestSharedLookup(t *testing.T) {
var chewie = &authenticator.Response{User: &user.DefaultInfo{Name: "chewbacca"}}
t.Run("actually shared", func(t *testing.T) {
var lookups uint32
c := make(chan struct{})
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
<-c
atomic.AddUint32(&lookups, 1)
return chewie, true, nil
}), true, time.Minute, 0)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
a.AuthenticateToken(context.Background(), "")
}()
}
// no good way to make sure that all the callers are queued so we sleep.
time.Sleep(1 * time.Second)
close(c)
wg.Wait()
if lookups > 3 {
t.Fatalf("unexpected number of lookups: got=%d, wanted less than 3", lookups)
}
})
t.Run("first caller bails, second caller gets result", func(t *testing.T) {
c := make(chan struct{})
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
<-c
return chewie, true, nil
}), true, time.Minute, 0)
var wg sync.WaitGroup
wg.Add(2)
ctx1, cancel1 := context.WithCancel(context.Background())
go func() {
defer wg.Done()
a.AuthenticateToken(ctx1, "")
}()
ctx2 := context.Background()
var (
resp *authenticator.Response
ok bool
err error
)
go func() {
defer wg.Done()
resp, ok, err = a.AuthenticateToken(ctx2, "")
}()
time.Sleep(1 * time.Second)
cancel1()
close(c)
wg.Wait()
if want := chewie; !cmp.Equal(resp, want) {
t.Errorf("Unexpected diff: %v", cmp.Diff(resp, want))
}
if !ok {
t.Errorf("Expected ok response")
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
})
t.Run("lookup panics", func(t *testing.T) {
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
panic("uh oh")
}), true, time.Minute, 0)
_, _, err := a.AuthenticateToken(context.Background(), "")
if err != errAuthnCrash {
t.Errorf("expected error: %v", err)
}
})
t.Run("audiences are forwarded", func(t *testing.T) {
ctx := authenticator.WithAudiences(context.Background(), []string{"a"})
a := New(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
auds, _ := authenticator.AudiencesFrom(ctx)
if got, want := auds, []string{"a"}; cmp.Equal(got, want) {
t.Fatalf("unexpeced audiences: %v", cmp.Diff(got, want))
}
return nil, false, nil
}), true, time.Minute, 0)
a.AuthenticateToken(ctx, "")
})
}
func BenchmarkCachedTokenAuthenticator(b *testing.B) {
tokenCount := []int{100, 500, 2500, 12500, 62500}
threadCount := []int{1, 16, 256}