mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-27 05:27:21 +00:00
Merge pull request #38167 from ericchiang/oidc-client-auth-cache-provider
Automatic merge from submit-queue (batch tested with PRs 39648, 38167, 39591, 39415, 39612) oidc client auth provider: cache OpenID Connect clients to prevent reinitialization Still need to add tests. closes #37876 cc @kubernetes/sig-auth @liggitt @jsloyer @mlbiam @philips
This commit is contained in:
commit
17665a009f
@ -14,7 +14,6 @@ go_library(
|
|||||||
tags = ["automanaged"],
|
tags = ["automanaged"],
|
||||||
deps = [
|
deps = [
|
||||||
"//pkg/client/restclient:go_default_library",
|
"//pkg/client/restclient:go_default_library",
|
||||||
"//pkg/util/wait:go_default_library",
|
|
||||||
"//vendor:github.com/coreos/go-oidc/jose",
|
"//vendor:github.com/coreos/go-oidc/jose",
|
||||||
"//vendor:github.com/coreos/go-oidc/oauth2",
|
"//vendor:github.com/coreos/go-oidc/oauth2",
|
||||||
"//vendor:github.com/coreos/go-oidc/oidc",
|
"//vendor:github.com/coreos/go-oidc/oidc",
|
||||||
@ -28,8 +27,6 @@ go_test(
|
|||||||
library = ":go_default_library",
|
library = ":go_default_library",
|
||||||
tags = ["automanaged"],
|
tags = ["automanaged"],
|
||||||
deps = [
|
deps = [
|
||||||
"//pkg/util/diff:go_default_library",
|
|
||||||
"//pkg/util/wait:go_default_library",
|
|
||||||
"//plugin/pkg/auth/authenticator/token/oidc/testing:go_default_library",
|
"//plugin/pkg/auth/authenticator/token/oidc/testing:go_default_library",
|
||||||
"//vendor:github.com/coreos/go-oidc/jose",
|
"//vendor:github.com/coreos/go-oidc/jose",
|
||||||
"//vendor:github.com/coreos/go-oidc/key",
|
"//vendor:github.com/coreos/go-oidc/key",
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/jose"
|
"github.com/coreos/go-oidc/jose"
|
||||||
@ -30,7 +31,6 @@ import (
|
|||||||
"github.com/golang/glog"
|
"github.com/golang/glog"
|
||||||
|
|
||||||
"k8s.io/kubernetes/pkg/client/restclient"
|
"k8s.io/kubernetes/pkg/client/restclient"
|
||||||
"k8s.io/kubernetes/pkg/util/wait"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -44,21 +44,68 @@ const (
|
|||||||
cfgRefreshToken = "refresh-token"
|
cfgRefreshToken = "refresh-token"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
backoff = wait.Backoff{
|
|
||||||
Duration: 1 * time.Second,
|
|
||||||
Factor: 2,
|
|
||||||
Jitter: .1,
|
|
||||||
Steps: 5,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if err := restclient.RegisterAuthProviderPlugin("oidc", newOIDCAuthProvider); err != nil {
|
if err := restclient.RegisterAuthProviderPlugin("oidc", newOIDCAuthProvider); err != nil {
|
||||||
glog.Fatalf("Failed to register oidc auth plugin: %v", err)
|
glog.Fatalf("Failed to register oidc auth plugin: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// expiryDelta determines how earlier a token should be considered
|
||||||
|
// expired than its actual expiration time. It is used to avoid late
|
||||||
|
// expirations due to client-server time mismatches.
|
||||||
|
//
|
||||||
|
// NOTE(ericchiang): this is take from golang.org/x/oauth2
|
||||||
|
const expiryDelta = 10 * time.Second
|
||||||
|
|
||||||
|
var cache = newClientCache()
|
||||||
|
|
||||||
|
// Like TLS transports, keep a cache of OIDC clients indexed by issuer URL.
|
||||||
|
type clientCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
cache map[cacheKey]*oidcAuthProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientCache() *clientCache {
|
||||||
|
return &clientCache{cache: make(map[cacheKey]*oidcAuthProvider)}
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheKey struct {
|
||||||
|
// Canonical issuer URL string of the provider.
|
||||||
|
issuerURL string
|
||||||
|
|
||||||
|
clientID string
|
||||||
|
clientSecret string
|
||||||
|
|
||||||
|
// Don't use CA as cache key because we only add a cache entry if we can connect
|
||||||
|
// to the issuer in the first place. A valid CA is a prerequisite.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientCache) getClient(issuer, clientID, clientSecret string) (*oidcAuthProvider, bool) {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
client, ok := c.cache[cacheKey{issuer, clientID, clientSecret}]
|
||||||
|
return client, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// setClient attempts to put the client in the cache but may return any clients
|
||||||
|
// with the same keys set before. This is so there's only ever one client for a provider.
|
||||||
|
func (c *clientCache) setClient(issuer, clientID, clientSecret string, client *oidcAuthProvider) *oidcAuthProvider {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
key := cacheKey{issuer, clientID, clientSecret}
|
||||||
|
|
||||||
|
// If another client has already initialized a client for the given provider we want
|
||||||
|
// to use that client instead of the one we're trying to set. This is so all transports
|
||||||
|
// share a client and can coordinate around the same mutex when refreshing and writing
|
||||||
|
// to the kubeconfig.
|
||||||
|
if oldClient, ok := c.cache[key]; ok {
|
||||||
|
return oldClient
|
||||||
|
}
|
||||||
|
|
||||||
|
c.cache[key] = client
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
|
func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
|
||||||
issuer := cfg[cfgIssuerUrl]
|
issuer := cfg[cfgIssuerUrl]
|
||||||
if issuer == "" {
|
if issuer == "" {
|
||||||
@ -75,6 +122,11 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
|
|||||||
return nil, fmt.Errorf("Must provide %s", cfgClientSecret)
|
return nil, fmt.Errorf("Must provide %s", cfgClientSecret)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check cache for existing provider.
|
||||||
|
if provider, ok := cache.getClient(issuer, clientID, clientSecret); ok {
|
||||||
|
return provider, nil
|
||||||
|
}
|
||||||
|
|
||||||
var certAuthData []byte
|
var certAuthData []byte
|
||||||
var err error
|
var err error
|
||||||
if cfg[cfgCertificateAuthorityData] != "" {
|
if cfg[cfgCertificateAuthorityData] != "" {
|
||||||
@ -112,146 +164,134 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
|
|||||||
ProviderConfig: providerCfg,
|
ProviderConfig: providerCfg,
|
||||||
Scope: append(scopes, oidc.DefaultScope...),
|
Scope: append(scopes, oidc.DefaultScope...),
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := oidc.NewClient(oidcCfg)
|
client, err := oidc.NewClient(oidcCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating OIDC Client: %v", err)
|
return nil, fmt.Errorf("error creating OIDC Client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oClient := &oidcClient{client}
|
provider := &oidcAuthProvider{
|
||||||
|
client: &oidcClient{client},
|
||||||
var initialIDToken jose.JWT
|
|
||||||
if cfg[cfgIDToken] != "" {
|
|
||||||
initialIDToken, err = jose.ParseJWT(cfg[cfgIDToken])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &oidcAuthProvider{
|
|
||||||
initialIDToken: initialIDToken,
|
|
||||||
refresher: &idTokenRefresher{
|
|
||||||
client: oClient,
|
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
persister: persister,
|
persister: persister,
|
||||||
},
|
now: time.Now,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
return cache.setClient(issuer, clientID, clientSecret, provider), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type oidcAuthProvider struct {
|
type oidcAuthProvider struct {
|
||||||
refresher *idTokenRefresher
|
// Interface rather than a raw *oidc.Client for testing.
|
||||||
initialIDToken jose.JWT
|
client OIDCClient
|
||||||
|
|
||||||
|
// Stubbed out for testing.
|
||||||
|
now func() time.Time
|
||||||
|
|
||||||
|
// Mutex guards persisting to the kubeconfig file and allows synchronized
|
||||||
|
// updates to the in-memory config. It also ensures concurrent calls to
|
||||||
|
// the RoundTripper only trigger a single refresh request.
|
||||||
|
mu sync.Mutex
|
||||||
|
cfg map[string]string
|
||||||
|
persister restclient.AuthProviderConfigPersister
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
|
func (p *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
|
||||||
at := &oidc.AuthenticatedTransport{
|
|
||||||
TokenRefresher: g.refresher,
|
|
||||||
RoundTripper: rt,
|
|
||||||
}
|
|
||||||
at.SetJWT(g.initialIDToken)
|
|
||||||
return &roundTripper{
|
return &roundTripper{
|
||||||
wrapped: at,
|
wrapped: rt,
|
||||||
refresher: g.refresher,
|
provider: p,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *oidcAuthProvider) Login() error {
|
func (p *oidcAuthProvider) Login() error {
|
||||||
return errors.New("not yet implemented")
|
return errors.New("not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCClient interface {
|
type OIDCClient interface {
|
||||||
refreshToken(rt string) (oauth2.TokenResponse, error)
|
refreshToken(rt string) (oauth2.TokenResponse, error)
|
||||||
verifyJWT(jwt jose.JWT) error
|
verifyJWT(jwt *jose.JWT) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type roundTripper struct {
|
type roundTripper struct {
|
||||||
refresher *idTokenRefresher
|
provider *oidcAuthProvider
|
||||||
wrapped *oidc.AuthenticatedTransport
|
wrapped http.RoundTripper
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
var res *http.Response
|
token, err := r.provider.idToken()
|
||||||
var err error
|
|
||||||
firstTime := true
|
|
||||||
wait.ExponentialBackoff(backoff, func() (bool, error) {
|
|
||||||
if !firstTime {
|
|
||||||
var jwt jose.JWT
|
|
||||||
jwt, err = r.refresher.Refresh()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, nil
|
return nil, err
|
||||||
}
|
|
||||||
r.wrapped.SetJWT(jwt)
|
|
||||||
} else {
|
|
||||||
firstTime = false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err = r.wrapped.RoundTrip(req)
|
// shallow copy of the struct
|
||||||
|
r2 := new(http.Request)
|
||||||
|
*r2 = *req
|
||||||
|
// deep copy of the Header so we don't modify the original
|
||||||
|
// request's Header (as per RoundTripper contract).
|
||||||
|
r2.Header = make(http.Header)
|
||||||
|
for k, s := range req.Header {
|
||||||
|
r2.Header[k] = s
|
||||||
|
}
|
||||||
|
r2.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
|
|
||||||
|
return r.wrapped.RoundTrip(r2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *oidcAuthProvider) idToken() (string, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
if idToken, ok := p.cfg[cfgIDToken]; ok && len(idToken) > 0 {
|
||||||
|
valid, err := verifyJWTExpiry(p.now(), idToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, nil
|
return "", err
|
||||||
}
|
}
|
||||||
if res.StatusCode == http.StatusUnauthorized {
|
if valid {
|
||||||
return false, nil
|
// If the cached id token is still valid use it.
|
||||||
|
return idToken, nil
|
||||||
}
|
}
|
||||||
return true, nil
|
|
||||||
})
|
|
||||||
return res, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type idTokenRefresher struct {
|
// Try to request a new token using the refresh token.
|
||||||
cfg map[string]string
|
rt, ok := p.cfg[cfgRefreshToken]
|
||||||
client OIDCClient
|
if !ok || len(rt) == 0 {
|
||||||
persister restclient.AuthProviderConfigPersister
|
return "", errors.New("No valid id-token, and cannot refresh without refresh-token")
|
||||||
intialIDToken jose.JWT
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *idTokenRefresher) Verify(jwt jose.JWT) error {
|
tokens, err := p.client.refreshToken(rt)
|
||||||
claims, err := jwt.Claims()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return "", fmt.Errorf("could not refresh token: %v", err)
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
exp, ok, err := claims.TimeClaim("exp")
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
return fmt.Errorf("failed to parse 'exp' claim: %v", err)
|
|
||||||
case !ok:
|
|
||||||
return errors.New("missing required 'exp' claim")
|
|
||||||
case exp.Before(now):
|
|
||||||
return fmt.Errorf("token already expired at: %v", exp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *idTokenRefresher) Refresh() (jose.JWT, error) {
|
|
||||||
rt, ok := r.cfg[cfgRefreshToken]
|
|
||||||
if !ok {
|
|
||||||
return jose.JWT{}, errors.New("No valid id-token, and cannot refresh without refresh-token")
|
|
||||||
}
|
|
||||||
|
|
||||||
tokens, err := r.client.refreshToken(rt)
|
|
||||||
if err != nil {
|
|
||||||
return jose.JWT{}, fmt.Errorf("could not refresh token: %v", err)
|
|
||||||
}
|
}
|
||||||
jwt, err := jose.ParseJWT(tokens.IDToken)
|
jwt, err := jose.ParseJWT(tokens.IDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return jose.JWT{}, err
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.client.verifyJWT(&jwt); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new config to persist.
|
||||||
|
newCfg := make(map[string]string)
|
||||||
|
for key, val := range p.cfg {
|
||||||
|
newCfg[key] = val
|
||||||
}
|
}
|
||||||
|
|
||||||
if tokens.RefreshToken != "" && tokens.RefreshToken != rt {
|
if tokens.RefreshToken != "" && tokens.RefreshToken != rt {
|
||||||
r.cfg[cfgRefreshToken] = tokens.RefreshToken
|
newCfg[cfgRefreshToken] = tokens.RefreshToken
|
||||||
}
|
|
||||||
r.cfg[cfgIDToken] = jwt.Encode()
|
|
||||||
|
|
||||||
err = r.persister.Persist(r.cfg)
|
|
||||||
if err != nil {
|
|
||||||
return jose.JWT{}, fmt.Errorf("could not perist new tokens: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return jwt, r.client.verifyJWT(jwt)
|
newCfg[cfgIDToken] = tokens.IDToken
|
||||||
|
if err = p.persister.Persist(newCfg); err != nil {
|
||||||
|
return "", fmt.Errorf("could not perist new tokens: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update the in memory config to reflect the on disk one.
|
||||||
|
p.cfg = newCfg
|
||||||
|
|
||||||
|
return tokens.IDToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// oidcClient is the real implementation of the OIDCClient interface, which is
|
||||||
|
// used for testing.
|
||||||
type oidcClient struct {
|
type oidcClient struct {
|
||||||
client *oidc.Client
|
client *oidc.Client
|
||||||
}
|
}
|
||||||
@ -265,6 +305,29 @@ func (o *oidcClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
|||||||
return oac.RequestToken(oauth2.GrantTypeRefreshToken, rt)
|
return oac.RequestToken(oauth2.GrantTypeRefreshToken, rt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *oidcClient) verifyJWT(jwt jose.JWT) error {
|
func (o *oidcClient) verifyJWT(jwt *jose.JWT) error {
|
||||||
return o.client.VerifyJWT(jwt)
|
return o.client.VerifyJWT(*jwt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyJWTExpiry(now time.Time, s string) (valid bool, err error) {
|
||||||
|
jwt, err := jose.ParseJWT(s)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("invalid %q", cfgIDToken)
|
||||||
|
}
|
||||||
|
claims, err := jwt.Claims()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
exp, ok, err := claims.TimeClaim("exp")
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return false, fmt.Errorf("failed to parse 'exp' claim: %v", err)
|
||||||
|
case !ok:
|
||||||
|
return false, errors.New("missing required 'exp' claim")
|
||||||
|
case exp.After(now.Add(expiryDelta)):
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -19,13 +19,10 @@ package oidc
|
|||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -33,11 +30,41 @@ import (
|
|||||||
"github.com/coreos/go-oidc/key"
|
"github.com/coreos/go-oidc/key"
|
||||||
"github.com/coreos/go-oidc/oauth2"
|
"github.com/coreos/go-oidc/oauth2"
|
||||||
|
|
||||||
"k8s.io/kubernetes/pkg/util/diff"
|
|
||||||
"k8s.io/kubernetes/pkg/util/wait"
|
|
||||||
oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing"
|
oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func clearCache() {
|
||||||
|
cache = newClientCache()
|
||||||
|
}
|
||||||
|
|
||||||
|
type persister struct{}
|
||||||
|
|
||||||
|
// we don't need to actually persist anything because there's no way for us to
|
||||||
|
// read from a persister.
|
||||||
|
func (p *persister) Persist(map[string]string) error { return nil }
|
||||||
|
|
||||||
|
type noRefreshOIDCClient struct{}
|
||||||
|
|
||||||
|
func (c *noRefreshOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
||||||
|
return oauth2.TokenResponse{}, errors.New("alwaysErrOIDCClient: cannot refresh token")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *noRefreshOIDCClient) verifyJWT(jwt *jose.JWT) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockOIDCClient struct {
|
||||||
|
tokenResponse oauth2.TokenResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mockOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
||||||
|
return c.tokenResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mockOIDCClient) verifyJWT(jwt *jose.JWT) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewOIDCAuthProvider(t *testing.T) {
|
func TestNewOIDCAuthProvider(t *testing.T) {
|
||||||
tempDir, err := ioutil.TempDir(os.TempDir(), "oidc_test")
|
tempDir, err := ioutil.TempDir(os.TempDir(), "oidc_test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -60,127 +87,211 @@ func TestNewOIDCAuthProvider(t *testing.T) {
|
|||||||
t.Fatalf("Could not read cert bytes %v", err)
|
t.Fatalf("Could not read cert bytes %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
makeToken := func(exp time.Time) *jose.JWT {
|
||||||
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
||||||
"test": "jwt",
|
"exp": exp.UTC().Unix(),
|
||||||
}), op.PrivKey.Signer())
|
}), op.PrivKey.Signer())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Could not create signed JWT %v", err)
|
t.Fatalf("Could not create signed JWT %v", err)
|
||||||
}
|
}
|
||||||
|
return jwt
|
||||||
|
}
|
||||||
|
|
||||||
|
t0 := time.Now()
|
||||||
|
|
||||||
|
goodToken := makeToken(t0.Add(time.Hour)).Encode()
|
||||||
|
expiredToken := makeToken(t0.Add(-time.Hour)).Encode()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
cfg map[string]string
|
name string
|
||||||
|
|
||||||
wantErr bool
|
cfg map[string]string
|
||||||
wantInitialIDToken jose.JWT
|
wantInitErr bool
|
||||||
|
|
||||||
|
client OIDCClient
|
||||||
|
wantCfg map[string]string
|
||||||
|
wantTokenErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
// A Valid configuration
|
// A Valid configuration
|
||||||
|
name: "no id token and no refresh token",
|
||||||
cfg: map[string]string{
|
cfg: map[string]string{
|
||||||
cfgIssuerUrl: srv.URL,
|
cfgIssuerUrl: srv.URL,
|
||||||
cfgCertificateAuthority: cert,
|
cfgCertificateAuthority: cert,
|
||||||
cfgClientID: "client-id",
|
cfgClientID: "client-id",
|
||||||
cfgClientSecret: "client-secret",
|
cfgClientSecret: "client-secret",
|
||||||
},
|
},
|
||||||
|
wantTokenErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// A Valid configuration with an Initial JWT
|
name: "valid config with an initial token",
|
||||||
cfg: map[string]string{
|
cfg: map[string]string{
|
||||||
cfgIssuerUrl: srv.URL,
|
cfgIssuerUrl: srv.URL,
|
||||||
cfgCertificateAuthority: cert,
|
cfgCertificateAuthority: cert,
|
||||||
cfgClientID: "client-id",
|
cfgClientID: "client-id",
|
||||||
cfgClientSecret: "client-secret",
|
cfgClientSecret: "client-secret",
|
||||||
cfgIDToken: jwt.Encode(),
|
cfgIDToken: goodToken,
|
||||||
|
},
|
||||||
|
client: new(noRefreshOIDCClient),
|
||||||
|
wantCfg: map[string]string{
|
||||||
|
cfgIssuerUrl: srv.URL,
|
||||||
|
cfgCertificateAuthority: cert,
|
||||||
|
cfgClientID: "client-id",
|
||||||
|
cfgClientSecret: "client-secret",
|
||||||
|
cfgIDToken: goodToken,
|
||||||
},
|
},
|
||||||
wantInitialIDToken: *jwt,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Valid config, but using cfgCertificateAuthorityData
|
name: "invalid ID token with a refresh token",
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgIssuerUrl: srv.URL,
|
||||||
|
cfgCertificateAuthority: cert,
|
||||||
|
cfgClientID: "client-id",
|
||||||
|
cfgClientSecret: "client-secret",
|
||||||
|
cfgRefreshToken: "foo",
|
||||||
|
cfgIDToken: expiredToken,
|
||||||
|
},
|
||||||
|
client: &mockOIDCClient{
|
||||||
|
tokenResponse: oauth2.TokenResponse{
|
||||||
|
IDToken: goodToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantCfg: map[string]string{
|
||||||
|
cfgIssuerUrl: srv.URL,
|
||||||
|
cfgCertificateAuthority: cert,
|
||||||
|
cfgClientID: "client-id",
|
||||||
|
cfgClientSecret: "client-secret",
|
||||||
|
cfgRefreshToken: "foo",
|
||||||
|
cfgIDToken: goodToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid ID token with a refresh token, server returns new refresh token",
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgIssuerUrl: srv.URL,
|
||||||
|
cfgCertificateAuthority: cert,
|
||||||
|
cfgClientID: "client-id",
|
||||||
|
cfgClientSecret: "client-secret",
|
||||||
|
cfgRefreshToken: "foo",
|
||||||
|
cfgIDToken: expiredToken,
|
||||||
|
},
|
||||||
|
client: &mockOIDCClient{
|
||||||
|
tokenResponse: oauth2.TokenResponse{
|
||||||
|
IDToken: goodToken,
|
||||||
|
RefreshToken: "bar",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantCfg: map[string]string{
|
||||||
|
cfgIssuerUrl: srv.URL,
|
||||||
|
cfgCertificateAuthority: cert,
|
||||||
|
cfgClientID: "client-id",
|
||||||
|
cfgClientSecret: "client-secret",
|
||||||
|
cfgRefreshToken: "bar",
|
||||||
|
cfgIDToken: goodToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired token and no refresh otken",
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgIssuerUrl: srv.URL,
|
||||||
|
cfgCertificateAuthority: cert,
|
||||||
|
cfgClientID: "client-id",
|
||||||
|
cfgClientSecret: "client-secret",
|
||||||
|
cfgIDToken: expiredToken,
|
||||||
|
},
|
||||||
|
wantTokenErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid base64d ca",
|
||||||
cfg: map[string]string{
|
cfg: map[string]string{
|
||||||
cfgIssuerUrl: srv.URL,
|
cfgIssuerUrl: srv.URL,
|
||||||
cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData),
|
cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData),
|
||||||
cfgClientID: "client-id",
|
cfgClientID: "client-id",
|
||||||
cfgClientSecret: "client-secret",
|
cfgClientSecret: "client-secret",
|
||||||
},
|
},
|
||||||
|
client: new(noRefreshOIDCClient),
|
||||||
|
wantTokenErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Missing client id
|
name: "missing client ID",
|
||||||
cfg: map[string]string{
|
cfg: map[string]string{
|
||||||
cfgIssuerUrl: srv.URL,
|
cfgIssuerUrl: srv.URL,
|
||||||
cfgCertificateAuthority: cert,
|
cfgCertificateAuthority: cert,
|
||||||
cfgClientSecret: "client-secret",
|
cfgClientSecret: "client-secret",
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantInitErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Missing client secret
|
name: "missing client secret",
|
||||||
cfg: map[string]string{
|
cfg: map[string]string{
|
||||||
cfgIssuerUrl: srv.URL,
|
cfgIssuerUrl: srv.URL,
|
||||||
cfgCertificateAuthority: cert,
|
cfgCertificateAuthority: cert,
|
||||||
cfgClientID: "client-id",
|
cfgClientID: "client-id",
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantInitErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Missing issuer url.
|
name: "missing issuer URL",
|
||||||
cfg: map[string]string{
|
cfg: map[string]string{
|
||||||
cfgCertificateAuthority: cert,
|
cfgCertificateAuthority: cert,
|
||||||
cfgClientID: "client-id",
|
cfgClientID: "client-id",
|
||||||
cfgClientSecret: "secret",
|
cfgClientSecret: "secret",
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantInitErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// No TLS config
|
name: "missing TLS config",
|
||||||
cfg: map[string]string{
|
cfg: map[string]string{
|
||||||
cfgIssuerUrl: srv.URL,
|
cfgIssuerUrl: srv.URL,
|
||||||
cfgClientID: "client-id",
|
cfgClientID: "client-id",
|
||||||
cfgClientSecret: "secret",
|
cfgClientSecret: "secret",
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantInitErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for _, tt := range tests {
|
||||||
ap, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, nil)
|
clearCache()
|
||||||
if tt.wantErr {
|
|
||||||
|
p, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, new(persister))
|
||||||
|
if tt.wantInitErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("case %d: want non-nil err", i)
|
t.Errorf("%s: want non-nil err", tt.name)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("case %d: unexpected error on newOIDCAuthProvider: %v", i, err)
|
t.Errorf("%s: unexpected error on newOIDCAuthProvider: %v", tt.name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcAP, ok := ap.(*oidcAuthProvider)
|
provider := p.(*oidcAuthProvider)
|
||||||
if !ok {
|
provider.client = tt.client
|
||||||
t.Errorf("case %d: expected ap to be an oidcAuthProvider", i)
|
provider.now = func() time.Time { return t0 }
|
||||||
|
|
||||||
|
if _, err := provider.idToken(); err != nil {
|
||||||
|
if !tt.wantTokenErr {
|
||||||
|
t.Errorf("%s: failed to get id token: %v", tt.name, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if tt.wantTokenErr {
|
||||||
|
t.Errorf("%s: expected to not get id token: %v", tt.name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken); diff != "" {
|
if !reflect.DeepEqual(tt.wantCfg, provider.cfg) {
|
||||||
t.Errorf("case %d: compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken)=%v", i, diff)
|
t.Errorf("%s: expected config %#v got %#v", tt.name, tt.wantCfg, provider.cfg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrapTranport(t *testing.T) {
|
func TestVerifyJWTExpiry(t *testing.T) {
|
||||||
oldBackoff := backoff
|
|
||||||
defer func() {
|
|
||||||
backoff = oldBackoff
|
|
||||||
}()
|
|
||||||
backoff = wait.Backoff{
|
|
||||||
Duration: 1 * time.Nanosecond,
|
|
||||||
Steps: 3,
|
|
||||||
}
|
|
||||||
|
|
||||||
privKey, err := key.GeneratePrivateKey()
|
privKey, err := key.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("can't generate private key: %v", err)
|
t.Fatalf("can't generate private key: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
makeToken := func(s string, exp time.Time, count int) *jose.JWT {
|
makeToken := func(s string, exp time.Time, count int) *jose.JWT {
|
||||||
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
||||||
"test": s,
|
"test": s,
|
||||||
@ -193,451 +304,81 @@ func TestWrapTranport(t *testing.T) {
|
|||||||
return jwt
|
return jwt
|
||||||
}
|
}
|
||||||
|
|
||||||
goodToken := makeToken("good", time.Now().Add(time.Hour), 0)
|
t0 := time.Now()
|
||||||
goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1)
|
|
||||||
expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0)
|
|
||||||
|
|
||||||
str := func(s string) *string {
|
|
||||||
return &s
|
|
||||||
}
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
cfgIDToken *jose.JWT
|
name string
|
||||||
cfgRefreshToken *string
|
jwt *jose.JWT
|
||||||
|
now time.Time
|
||||||
expectRequests []testRoundTrip
|
|
||||||
|
|
||||||
expectRefreshes []testRefresh
|
|
||||||
|
|
||||||
expectPersists []testPersist
|
|
||||||
|
|
||||||
wantStatus int
|
|
||||||
wantErr bool
|
wantErr bool
|
||||||
|
wantExpired bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
// Initial JWT is set, it is good, it is set as bearer.
|
name: "valid jwt",
|
||||||
cfgIDToken: goodToken,
|
jwt: makeToken("foo", t0.Add(time.Hour), 1),
|
||||||
|
now: t0,
|
||||||
expectRequests: []testRoundTrip{
|
|
||||||
{
|
|
||||||
expectBearerToken: goodToken.Encode(),
|
|
||||||
returnHTTPStatus: 200,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
wantStatus: 200,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Initial JWT is set, but it's expired, so it gets refreshed.
|
name: "invalid jwt",
|
||||||
cfgIDToken: expiredToken,
|
jwt: &jose.JWT{},
|
||||||
cfgRefreshToken: str("rt1"),
|
now: t0,
|
||||||
|
|
||||||
expectRefreshes: []testRefresh{
|
|
||||||
{
|
|
||||||
expectRefreshToken: "rt1",
|
|
||||||
returnTokens: oauth2.TokenResponse{
|
|
||||||
IDToken: goodToken.Encode(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
expectRequests: []testRoundTrip{
|
|
||||||
{
|
|
||||||
expectBearerToken: goodToken.Encode(),
|
|
||||||
returnHTTPStatus: 200,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
expectPersists: []testPersist{
|
|
||||||
{
|
|
||||||
cfg: map[string]string{
|
|
||||||
cfgIDToken: goodToken.Encode(),
|
|
||||||
cfgRefreshToken: "rt1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
wantStatus: 200,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Initial JWT is set, but it's expired, so it gets refreshed - this
|
|
||||||
// time the refresh token itself is also refreshed
|
|
||||||
cfgIDToken: expiredToken,
|
|
||||||
cfgRefreshToken: str("rt1"),
|
|
||||||
|
|
||||||
expectRefreshes: []testRefresh{
|
|
||||||
{
|
|
||||||
expectRefreshToken: "rt1",
|
|
||||||
returnTokens: oauth2.TokenResponse{
|
|
||||||
IDToken: goodToken.Encode(),
|
|
||||||
RefreshToken: "rt2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
expectRequests: []testRoundTrip{
|
|
||||||
{
|
|
||||||
expectBearerToken: goodToken.Encode(),
|
|
||||||
returnHTTPStatus: 200,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
expectPersists: []testPersist{
|
|
||||||
{
|
|
||||||
cfg: map[string]string{
|
|
||||||
cfgIDToken: goodToken.Encode(),
|
|
||||||
cfgRefreshToken: "rt2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
wantStatus: 200,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Initial JWT is not set, so it gets refreshed.
|
|
||||||
cfgRefreshToken: str("rt1"),
|
|
||||||
|
|
||||||
expectRefreshes: []testRefresh{
|
|
||||||
{
|
|
||||||
expectRefreshToken: "rt1",
|
|
||||||
returnTokens: oauth2.TokenResponse{
|
|
||||||
IDToken: goodToken.Encode(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
expectRequests: []testRoundTrip{
|
|
||||||
{
|
|
||||||
expectBearerToken: goodToken.Encode(),
|
|
||||||
returnHTTPStatus: 200,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
expectPersists: []testPersist{
|
|
||||||
{
|
|
||||||
cfg: map[string]string{
|
|
||||||
cfgIDToken: goodToken.Encode(),
|
|
||||||
cfgRefreshToken: "rt1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
wantStatus: 200,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Expired token, but no refresh token.
|
|
||||||
cfgIDToken: expiredToken,
|
|
||||||
|
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Initial JWT is not set, so it gets refreshed, but the server
|
name: "expired jwt",
|
||||||
// rejects it when it is used, so it refreshes again, which
|
jwt: makeToken("foo", t0.Add(-time.Hour), 1),
|
||||||
// succeeds.
|
now: t0,
|
||||||
cfgRefreshToken: str("rt1"),
|
wantExpired: true,
|
||||||
|
|
||||||
expectRefreshes: []testRefresh{
|
|
||||||
{
|
|
||||||
expectRefreshToken: "rt1",
|
|
||||||
returnTokens: oauth2.TokenResponse{
|
|
||||||
IDToken: goodToken.Encode(),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
expectRefreshToken: "rt1",
|
name: "jwt expires soon enough to be marked expired",
|
||||||
returnTokens: oauth2.TokenResponse{
|
jwt: makeToken("foo", t0, 1),
|
||||||
IDToken: goodToken2.Encode(),
|
now: t0,
|
||||||
},
|
wantExpired: true,
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
expectRequests: []testRoundTrip{
|
|
||||||
{
|
|
||||||
expectBearerToken: goodToken.Encode(),
|
|
||||||
returnHTTPStatus: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expectBearerToken: goodToken2.Encode(),
|
|
||||||
returnHTTPStatus: http.StatusOK,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
expectPersists: []testPersist{
|
|
||||||
{
|
|
||||||
cfg: map[string]string{
|
|
||||||
cfgIDToken: goodToken.Encode(),
|
|
||||||
cfgRefreshToken: "rt1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
cfg: map[string]string{
|
|
||||||
cfgIDToken: goodToken2.Encode(),
|
|
||||||
cfgRefreshToken: "rt1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
wantStatus: 200,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
// Initial JWT is but the server rejects it when it is used, so it
|
|
||||||
// refreshes again, which succeeds.
|
|
||||||
cfgRefreshToken: str("rt1"),
|
|
||||||
cfgIDToken: goodToken,
|
|
||||||
|
|
||||||
expectRefreshes: []testRefresh{
|
|
||||||
{
|
|
||||||
expectRefreshToken: "rt1",
|
|
||||||
returnTokens: oauth2.TokenResponse{
|
|
||||||
IDToken: goodToken2.Encode(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
expectRequests: []testRoundTrip{
|
|
||||||
{
|
|
||||||
expectBearerToken: goodToken.Encode(),
|
|
||||||
returnHTTPStatus: http.StatusUnauthorized,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expectBearerToken: goodToken2.Encode(),
|
|
||||||
returnHTTPStatus: http.StatusOK,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
expectPersists: []testPersist{
|
|
||||||
{
|
|
||||||
cfg: map[string]string{
|
|
||||||
cfgIDToken: goodToken2.Encode(),
|
|
||||||
cfgRefreshToken: "rt1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantStatus: 200,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for _, tc := range tests {
|
||||||
client := &testOIDCClient{
|
func() {
|
||||||
refreshes: tt.expectRefreshes,
|
valid, err := verifyJWTExpiry(tc.now, tc.jwt.Encode())
|
||||||
}
|
|
||||||
|
|
||||||
persister := &testPersister{
|
|
||||||
tt.expectPersists,
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := map[string]string{}
|
|
||||||
if tt.cfgIDToken != nil {
|
|
||||||
cfg[cfgIDToken] = tt.cfgIDToken.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.cfgRefreshToken != nil {
|
|
||||||
cfg[cfgRefreshToken] = *tt.cfgRefreshToken
|
|
||||||
}
|
|
||||||
|
|
||||||
ap := &oidcAuthProvider{
|
|
||||||
refresher: &idTokenRefresher{
|
|
||||||
client: client,
|
|
||||||
cfg: cfg,
|
|
||||||
persister: persister,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.cfgIDToken != nil {
|
|
||||||
ap.initialIDToken = *tt.cfgIDToken
|
|
||||||
}
|
|
||||||
|
|
||||||
tstRT := &testRoundTripper{
|
|
||||||
tt.expectRequests,
|
|
||||||
}
|
|
||||||
|
|
||||||
rt := ap.WrapTransport(tstRT)
|
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", "http://cluster.example.com", nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("case %d: unexpected error making request: %v", i, err)
|
if !tc.wantErr {
|
||||||
|
t.Errorf("%s: %v", tc.name, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tc.wantErr {
|
||||||
|
t.Errorf("%s: expected error", tc.name)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := rt.RoundTrip(req)
|
if valid && tc.wantExpired {
|
||||||
if tt.wantErr {
|
t.Errorf("%s: expected token to be expired", tc.name)
|
||||||
if err == nil {
|
|
||||||
t.Errorf("case %d: Expected non-nil error", i)
|
|
||||||
}
|
}
|
||||||
} else if err != nil {
|
if !valid && !tc.wantExpired {
|
||||||
t.Errorf("case %d: unexpected error making round trip: %v", i, err)
|
t.Errorf("%s: expected token to be valid", tc.name)
|
||||||
|
}
|
||||||
} else {
|
}()
|
||||||
if res.StatusCode != tt.wantStatus {
|
|
||||||
t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = client.verify(); err != nil {
|
func TestClientCache(t *testing.T) {
|
||||||
t.Errorf("case %d: %v", i, err)
|
cache := newClientCache()
|
||||||
|
|
||||||
|
if _, ok := cache.getClient("issuer1", "id1", "secret1"); ok {
|
||||||
|
t.Fatalf("got client before putting one in the cache")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = persister.verify(); err != nil {
|
cli1 := new(oidcAuthProvider)
|
||||||
t.Errorf("case %d: %v", i, err)
|
cli2 := new(oidcAuthProvider)
|
||||||
}
|
|
||||||
|
gotcli := cache.setClient("issuer1", "id1", "secret1", cli1)
|
||||||
if err = tstRT.verify(); err != nil {
|
if cli1 != gotcli {
|
||||||
t.Errorf("case %d: %v", i, err)
|
t.Fatalf("set first client and got a different one")
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gotcli = cache.setClient("issuer1", "id1", "secret1", cli2)
|
||||||
|
if cli1 != gotcli {
|
||||||
|
t.Fatalf("set a second client and didn't get the first")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type testRoundTrip struct {
|
|
||||||
expectBearerToken string
|
|
||||||
returnHTTPStatus int
|
|
||||||
}
|
|
||||||
|
|
||||||
type testRoundTripper struct {
|
|
||||||
trips []testRoundTrip
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
if len(t.trips) == 0 {
|
|
||||||
return nil, errors.New("unexpected RoundTrip call")
|
|
||||||
}
|
|
||||||
|
|
||||||
var trip testRoundTrip
|
|
||||||
trip, t.trips = t.trips[0], t.trips[1:]
|
|
||||||
|
|
||||||
var bt string
|
|
||||||
var parts []string
|
|
||||||
auth := strings.TrimSpace(req.Header.Get("Authorization"))
|
|
||||||
if auth == "" {
|
|
||||||
goto Compare
|
|
||||||
}
|
|
||||||
|
|
||||||
parts = strings.Split(auth, " ")
|
|
||||||
if len(parts) < 2 || strings.ToLower(parts[0]) != "bearer" {
|
|
||||||
goto Compare
|
|
||||||
}
|
|
||||||
|
|
||||||
bt = parts[1]
|
|
||||||
|
|
||||||
Compare:
|
|
||||||
if trip.expectBearerToken != bt {
|
|
||||||
return nil, fmt.Errorf("want bearerToken=%v, got=%v", trip.expectBearerToken, bt)
|
|
||||||
}
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: trip.returnHTTPStatus,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testRoundTripper) verify() error {
|
|
||||||
if l := len(t.trips); l > 0 {
|
|
||||||
return fmt.Errorf("%d uncalled round trips", l)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type testPersist struct {
|
|
||||||
cfg map[string]string
|
|
||||||
returnErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
type testPersister struct {
|
|
||||||
persists []testPersist
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testPersister) Persist(cfg map[string]string) error {
|
|
||||||
if len(t.persists) == 0 {
|
|
||||||
return errors.New("unexpected persist call")
|
|
||||||
}
|
|
||||||
|
|
||||||
var persist testPersist
|
|
||||||
persist, t.persists = t.persists[0], t.persists[1:]
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(persist.cfg, cfg) {
|
|
||||||
return fmt.Errorf("Unexpected cfg: %v", diff.ObjectDiff(persist.cfg, cfg))
|
|
||||||
}
|
|
||||||
|
|
||||||
return persist.returnErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testPersister) verify() error {
|
|
||||||
if l := len(t.persists); l > 0 {
|
|
||||||
return fmt.Errorf("%d uncalled persists", l)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type testRefresh struct {
|
|
||||||
expectRefreshToken string
|
|
||||||
|
|
||||||
returnErr error
|
|
||||||
returnTokens oauth2.TokenResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
type testOIDCClient struct {
|
|
||||||
refreshes []testRefresh
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *testOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
|
||||||
if len(o.refreshes) == 0 {
|
|
||||||
return oauth2.TokenResponse{}, errors.New("unexpected refresh request")
|
|
||||||
}
|
|
||||||
|
|
||||||
var refresh testRefresh
|
|
||||||
refresh, o.refreshes = o.refreshes[0], o.refreshes[1:]
|
|
||||||
|
|
||||||
if rt != refresh.expectRefreshToken {
|
|
||||||
return oauth2.TokenResponse{}, fmt.Errorf("want rt=%v, got=%v",
|
|
||||||
refresh.expectRefreshToken,
|
|
||||||
rt)
|
|
||||||
}
|
|
||||||
|
|
||||||
if refresh.returnErr != nil {
|
|
||||||
return oauth2.TokenResponse{}, refresh.returnErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return refresh.returnTokens, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *testOIDCClient) verifyJWT(jwt jose.JWT) error {
|
|
||||||
claims, err := jwt.Claims()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
claim, _, _ := claims.StringClaim("test")
|
|
||||||
if claim != "good" {
|
|
||||||
return errors.New("bad token")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testOIDCClient) verify() error {
|
|
||||||
if l := len(t.refreshes); l > 0 {
|
|
||||||
return fmt.Errorf("%d uncalled refreshes", l)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareJWTs(a, b jose.JWT) string {
|
|
||||||
if a.Encode() == b.Encode() {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var aClaims, bClaims jose.Claims
|
|
||||||
for _, j := range []struct {
|
|
||||||
claims *jose.Claims
|
|
||||||
jwt jose.JWT
|
|
||||||
}{
|
|
||||||
{&aClaims, a},
|
|
||||||
{&bClaims, b},
|
|
||||||
} {
|
|
||||||
var err error
|
|
||||||
*j.claims, err = j.jwt.Claims()
|
|
||||||
if err != nil {
|
|
||||||
*j.claims = jose.Claims(map[string]interface{}{
|
|
||||||
"msg": "bad claims",
|
|
||||||
"err": err,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return diff.ObjectDiff(aClaims, bClaims)
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user