From 82bdf9051c432dbf1888bc9b70ab8a3bc27d175f Mon Sep 17 00:00:00 2001 From: Bobby Rullo Date: Wed, 4 May 2016 09:47:11 -0700 Subject: [PATCH 1/6] Update github.com/coreos/go-oidc --- Godeps/Godeps.json | 10 +++---- vendor/github.com/coreos/go-oidc/jose/sig.go | 3 +- .../coreos/go-oidc/jose/sig_hmac.go | 3 +- .../github.com/coreos/go-oidc/jose/sig_rsa.go | 3 +- vendor/github.com/coreos/go-oidc/key/key.go | 2 +- .../coreos/go-oidc/oauth2/oauth2.go | 30 +++++++++++++++++++ vendor/github.com/coreos/go-oidc/oidc/key.go | 12 +++++++- .../coreos/go-oidc/oidc/provider.go | 7 ++++- .../coreos/go-oidc/oidc/transport.go | 9 ++++++ 9 files changed, 65 insertions(+), 14 deletions(-) diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 7c16e912b3c..1670559a2b9 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -485,23 +485,23 @@ }, { "ImportPath": "github.com/coreos/go-oidc/http", - "Rev": "d7cb66526fffc811d602b6770581064f4b66b507" + "Rev": "5cf2aa52da8c574d3aa4458f471ad6ae2240fe6b" }, { "ImportPath": "github.com/coreos/go-oidc/jose", - "Rev": "d7cb66526fffc811d602b6770581064f4b66b507" + "Rev": "5cf2aa52da8c574d3aa4458f471ad6ae2240fe6b" }, { "ImportPath": "github.com/coreos/go-oidc/key", - "Rev": "d7cb66526fffc811d602b6770581064f4b66b507" + "Rev": "5cf2aa52da8c574d3aa4458f471ad6ae2240fe6b" }, { "ImportPath": "github.com/coreos/go-oidc/oauth2", - "Rev": "d7cb66526fffc811d602b6770581064f4b66b507" + "Rev": "5cf2aa52da8c574d3aa4458f471ad6ae2240fe6b" }, { "ImportPath": "github.com/coreos/go-oidc/oidc", - "Rev": "d7cb66526fffc811d602b6770581064f4b66b507" + "Rev": "5cf2aa52da8c574d3aa4458f471ad6ae2240fe6b" }, { "ImportPath": "github.com/coreos/go-semver/semver", diff --git a/vendor/github.com/coreos/go-oidc/jose/sig.go b/vendor/github.com/coreos/go-oidc/jose/sig.go index 220681bd758..7b2b253cca5 100644 --- a/vendor/github.com/coreos/go-oidc/jose/sig.go +++ b/vendor/github.com/coreos/go-oidc/jose/sig.go @@ -2,7 +2,6 @@ package jose import ( "fmt" - "strings" ) type Verifier interface { @@ -17,7 +16,7 @@ type Signer interface { } func NewVerifier(jwk JWK) (Verifier, error) { - if strings.ToUpper(jwk.Type) != "RSA" { + if jwk.Type != "RSA" { return nil, fmt.Errorf("unsupported key type %q", jwk.Type) } diff --git a/vendor/github.com/coreos/go-oidc/jose/sig_hmac.go b/vendor/github.com/coreos/go-oidc/jose/sig_hmac.go index bcf42b707ed..b3ca3ef3d49 100644 --- a/vendor/github.com/coreos/go-oidc/jose/sig_hmac.go +++ b/vendor/github.com/coreos/go-oidc/jose/sig_hmac.go @@ -7,7 +7,6 @@ import ( _ "crypto/sha256" "errors" "fmt" - "strings" ) type VerifierHMAC struct { @@ -21,7 +20,7 @@ type SignerHMAC struct { } func NewVerifierHMAC(jwk JWK) (*VerifierHMAC, error) { - if strings.ToUpper(jwk.Alg) != "HS256" { + if jwk.Alg != "" && jwk.Alg != "HS256" { return nil, fmt.Errorf("unsupported key algorithm %q", jwk.Alg) } diff --git a/vendor/github.com/coreos/go-oidc/jose/sig_rsa.go b/vendor/github.com/coreos/go-oidc/jose/sig_rsa.go index 57066f02f10..004e45dd835 100644 --- a/vendor/github.com/coreos/go-oidc/jose/sig_rsa.go +++ b/vendor/github.com/coreos/go-oidc/jose/sig_rsa.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "crypto/rsa" "fmt" - "strings" ) type VerifierRSA struct { @@ -20,7 +19,7 @@ type SignerRSA struct { } func NewVerifierRSA(jwk JWK) (*VerifierRSA, error) { - if strings.ToUpper(jwk.Alg) != "RS256" { + if jwk.Alg != "" && jwk.Alg != "RS256" { return nil, fmt.Errorf("unsupported key algorithm %q", jwk.Alg) } diff --git a/vendor/github.com/coreos/go-oidc/key/key.go b/vendor/github.com/coreos/go-oidc/key/key.go index de6250373d0..d0142a9e0e0 100644 --- a/vendor/github.com/coreos/go-oidc/key/key.go +++ b/vendor/github.com/coreos/go-oidc/key/key.go @@ -20,7 +20,7 @@ type PublicKey struct { } func (k *PublicKey) MarshalJSON() ([]byte, error) { - return json.Marshal(k.jwk) + return json.Marshal(&k.jwk) } func (k *PublicKey) UnmarshalJSON(data []byte) error { diff --git a/vendor/github.com/coreos/go-oidc/oauth2/oauth2.go b/vendor/github.com/coreos/go-oidc/oauth2/oauth2.go index 14bd6cd3f5e..1c68293a0a8 100644 --- a/vendor/github.com/coreos/go-oidc/oauth2/oauth2.go +++ b/vendor/github.com/coreos/go-oidc/oauth2/oauth2.go @@ -56,6 +56,7 @@ const ( const ( GrantTypeAuthCode = "authorization_code" GrantTypeClientCreds = "client_credentials" + GrantTypeUserCreds = "password" GrantTypeImplicit = "implicit" GrantTypeRefreshToken = "refresh_token" @@ -140,6 +141,11 @@ func NewClient(hc phttp.Client, cfg Config) (c *Client, err error) { return } +// Return the embedded HTTP client +func (c *Client) HttpClient() phttp.Client { + return c.hc +} + // Generate the url for initial redirect to oauth provider. func (c *Client) AuthCodeURL(state, accessType, prompt string) string { v := c.commonURLValues() @@ -220,6 +226,30 @@ func (c *Client) ClientCredsToken(scope []string) (result TokenResponse, err err return parseTokenResponse(resp) } +// UserCredsToken posts the username and password to obtain a token scoped to the OAuth2 client via the "password" grant_type +// May not be supported by all OAuth2 servers. +func (c *Client) UserCredsToken(username, password string) (result TokenResponse, err error) { + v := url.Values{ + "scope": {strings.Join(c.scope, " ")}, + "grant_type": {GrantTypeUserCreds}, + "username": {username}, + "password": {password}, + } + + req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v) + if err != nil { + return + } + + resp, err := c.hc.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + return parseTokenResponse(resp) +} + // RequestToken requests a token from the Token Endpoint with the specified grantType. // If 'grantType' == GrantTypeAuthCode, then 'value' should be the authorization code. // If 'grantType' == GrantTypeRefreshToken, then 'value' should be the refresh token. diff --git a/vendor/github.com/coreos/go-oidc/oidc/key.go b/vendor/github.com/coreos/go-oidc/oidc/key.go index 004310d445b..82a0f567d57 100644 --- a/vendor/github.com/coreos/go-oidc/oidc/key.go +++ b/vendor/github.com/coreos/go-oidc/oidc/key.go @@ -11,6 +11,11 @@ import ( "github.com/coreos/go-oidc/key" ) +// DefaultPublicKeySetTTL is the default TTL set on the PublicKeySet if no +// Cache-Control header is provided by the JWK Set document endpoint. +const DefaultPublicKeySetTTL = 24 * time.Hour + +// NewRemotePublicKeyRepo is responsible for fetching the JWK Set document. func NewRemotePublicKeyRepo(hc phttp.Client, ep string) *remotePublicKeyRepo { return &remotePublicKeyRepo{hc: hc, ep: ep} } @@ -20,6 +25,11 @@ type remotePublicKeyRepo struct { ep string } +// Get returns a PublicKeySet fetched from the JWK Set document endpoint. A TTL +// is set on the Key Set to avoid it having to be re-retrieved for every +// encryption event. This TTL is typically controlled by the endpoint returning +// a Cache-Control header, but defaults to 24 hours if no Cache-Control header +// is found. func (r *remotePublicKeyRepo) Get() (key.KeySet, error) { req, err := http.NewRequest("GET", r.ep, nil) if err != nil { @@ -48,7 +58,7 @@ func (r *remotePublicKeyRepo) Get() (key.KeySet, error) { return nil, err } if !ok { - return nil, errors.New("HTTP cache headers not set") + ttl = DefaultPublicKeySetTTL } exp := time.Now().UTC().Add(ttl) diff --git a/vendor/github.com/coreos/go-oidc/oidc/provider.go b/vendor/github.com/coreos/go-oidc/oidc/provider.go index 807cf00adec..1235890c0c2 100644 --- a/vendor/github.com/coreos/go-oidc/oidc/provider.go +++ b/vendor/github.com/coreos/go-oidc/oidc/provider.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/url" + "strings" "sync" "time" @@ -618,7 +619,11 @@ func NewHTTPProviderConfigGetter(hc phttp.Client, issuerURL string) *httpProvide } func (r *httpProviderConfigGetter) Get() (cfg ProviderConfig, err error) { - req, err := http.NewRequest("GET", r.issuerURL+discoveryConfigPath, nil) + // If the Issuer value contains a path component, any terminating / MUST be removed before + // appending /.well-known/openid-configuration. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationRequest + discoveryURL := strings.TrimSuffix(r.issuerURL, "/") + discoveryConfigPath + req, err := http.NewRequest("GET", discoveryURL, nil) if err != nil { return } diff --git a/vendor/github.com/coreos/go-oidc/oidc/transport.go b/vendor/github.com/coreos/go-oidc/oidc/transport.go index 93ff9e1f570..61c926d7fe7 100644 --- a/vendor/github.com/coreos/go-oidc/oidc/transport.go +++ b/vendor/github.com/coreos/go-oidc/oidc/transport.go @@ -67,6 +67,15 @@ func (t *AuthenticatedTransport) verifiedJWT() (jose.JWT, error) { return t.jwt, nil } +// SetJWT sets the JWT held by the Transport. +// This is useful for cases in which you want to set an initial JWT. +func (t *AuthenticatedTransport) SetJWT(jwt jose.JWT) { + t.mu.Lock() + defer t.mu.Unlock() + + t.jwt = jwt +} + func (t *AuthenticatedTransport) RoundTrip(r *http.Request) (*http.Response, error) { jwt, err := t.verifiedJWT() if err != nil { From f2135bdf90cbe72ab599d52e101132a0eb9408db Mon Sep 17 00:00:00 2001 From: Bobby Rullo Date: Fri, 6 May 2016 10:33:51 -0700 Subject: [PATCH 2/6] Implement new OIDC client AuthProvider This commit handles: * Passing ID Token as Bearer token * Refreshing of tokens using refresh-tokens * Persisting refreshed tokens * ability to add arbitrary extra scopes via config * this is what enables the cross-client/azp stuff --- plugin/pkg/client/auth/oidc/OWNERS | 2 + plugin/pkg/client/auth/oidc/oidc.go | 196 ++++++++++++++++++++++++++++ plugin/pkg/client/auth/plugins.go | 1 + 3 files changed, 199 insertions(+) create mode 100644 plugin/pkg/client/auth/oidc/OWNERS create mode 100644 plugin/pkg/client/auth/oidc/oidc.go diff --git a/plugin/pkg/client/auth/oidc/OWNERS b/plugin/pkg/client/auth/oidc/OWNERS new file mode 100644 index 00000000000..ecf33499349 --- /dev/null +++ b/plugin/pkg/client/auth/oidc/OWNERS @@ -0,0 +1,2 @@ +assignees: + - bobbyrullo diff --git a/plugin/pkg/client/auth/oidc/oidc.go b/plugin/pkg/client/auth/oidc/oidc.go new file mode 100644 index 00000000000..c752ba2b4a8 --- /dev/null +++ b/plugin/pkg/client/auth/oidc/oidc.go @@ -0,0 +1,196 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package oidc + +import ( + "encoding/base64" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/coreos/go-oidc/jose" + "github.com/coreos/go-oidc/oauth2" + "github.com/coreos/go-oidc/oidc" + "github.com/golang/glog" + + "k8s.io/kubernetes/pkg/client/restclient" +) + +const ( + cfgIssuerUrl = "idp-issuer-url" + cfgClientId = "client-id" + cfgClientSecret = "client-secret" + cfgCertificateAuthority = "idp-certificate-authority" + cfgCertificateAuthorityData = "idp-certificate-authority-data" + cfgExtraScopes = "extra-scopes" + cfgIdToken = "id-token" + cfgRefreshToken = "refresh-token" +) + +func init() { + if err := restclient.RegisterAuthProviderPlugin("oidc", newOIDCAuthProvider); err != nil { + glog.Fatalf("Failed to register oidc auth plugin: %v", err) + } +} + +func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) { + var certAuthData []byte + var err error + if cfg[cfgCertificateAuthorityData] != "" { + certAuthData, err = base64.StdEncoding.DecodeString(cfg[cfgCertificateAuthorityData]) + if err != nil { + return nil, err + } + } + + clientConfig := restclient.Config{ + TLSClientConfig: restclient.TLSClientConfig{ + CAFile: cfg[cfgCertificateAuthority], + CAData: certAuthData, + }, + } + + trans, err := restclient.TransportFor(&clientConfig) + if err != nil { + return nil, err + } + hc := &http.Client{Transport: trans} + + issuer, ok := cfg[cfgIssuerUrl] + if !ok || issuer == "" { + return nil, errors.New("Must provide idp-issuer-url") + } + + providerCfg, err := oidc.FetchProviderConfig(hc, strings.TrimSuffix(issuer, "/")) + if err != nil { + return nil, fmt.Errorf("error fetching provider config: %v", err) + } + + scopes := strings.Split(cfg[cfgExtraScopes], ",") + oidcCfg := oidc.ClientConfig{ + HTTPClient: hc, + Credentials: oidc.ClientCredentials{ + ID: cfg[cfgClientId], + Secret: cfg[cfgClientSecret], + }, + ProviderConfig: providerCfg, + Scope: scopes, + } + + client, err := oidc.NewClient(oidcCfg) + if err != nil { + return nil, fmt.Errorf("error creating OIDC Client: %v", err) + } + + var initialIDToken jose.JWT + if cfg[cfgIdToken] != "" { + initialIDToken, err = jose.ParseJWT(cfg[cfgIdToken]) + if err != nil { + return nil, err + } + } + + return &oidcAuthProvider{ + intialIDToken: initialIDToken, + refresher: &idTokenRefresher{ + client: client, + cfg: cfg, + persister: persister, + }, + }, nil +} + +type oidcAuthProvider struct { + refresher *idTokenRefresher + intialIDToken jose.JWT +} + +func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper { + at := &oidc.AuthenticatedTransport{ + TokenRefresher: g.refresher, + RoundTripper: rt, + } + at.SetJWT(g.intialIDToken) + return at +} + +func (g *oidcAuthProvider) Login() error { + return errors.New("not yet implemented") +} + +type idTokenRefresher struct { + cfg map[string]string + client *oidc.Client + persister restclient.AuthProviderConfigPersister + intialIDToken jose.JWT +} + +func (r *idTokenRefresher) Verify(jwt jose.JWT) error { + claims, err := jwt.Claims() + if err != nil { + return err + } + + now := time.Now() + exp, ok, err := claims.TimeClaim("exp") + switch { + case err != nil: + return fmt.Errorf("failed to parse 'exp' claim: %v", err) + case !ok: + return errors.New("missing required 'exp' claim") + case exp.Before(now): + return fmt.Errorf("token already expired at: %v", exp) + } + + return nil +} + +func (r *idTokenRefresher) Refresh() (jose.JWT, error) { + rt, ok := r.cfg[cfgRefreshToken] + if !ok { + return jose.JWT{}, errors.New("No valid id-token, and cannot refresh without refresh-token") + } + + oac, err := r.client.OAuthClient() + if err != nil { + return jose.JWT{}, err + } + + tokens, err := oac.RequestToken(oauth2.GrantTypeRefreshToken, rt) + if err != nil { + return jose.JWT{}, err + } + + jwt, err := jose.ParseJWT(tokens.IDToken) + if err != nil { + return jose.JWT{}, err + } + + if tokens.RefreshToken != "" && tokens.RefreshToken != rt { + r.cfg[cfgRefreshToken] = tokens.RefreshToken + } + r.cfg[cfgIdToken] = jwt.Encode() + + err = r.persister.Persist(r.cfg) + if err != nil { + return jose.JWT{}, fmt.Errorf("could not perist new tokens: %v", err) + } + + return jwt, r.client.VerifyJWT(jwt) +} diff --git a/plugin/pkg/client/auth/plugins.go b/plugin/pkg/client/auth/plugins.go index c93cfd1d939..2b422ddda02 100644 --- a/plugin/pkg/client/auth/plugins.go +++ b/plugin/pkg/client/auth/plugins.go @@ -19,4 +19,5 @@ package plugins import ( // Initialize all known client auth plugins. _ "k8s.io/kubernetes/plugin/pkg/client/auth/gcp" + _ "k8s.io/kubernetes/plugin/pkg/client/auth/oidc" ) From c990462d0fb80a5d03bb0345226028ab8c0aa265 Mon Sep 17 00:00:00 2001 From: Bobby Rullo Date: Wed, 11 May 2016 15:09:42 -0700 Subject: [PATCH 3/6] Refactor test oidc provider into its own package This makes it easier to test other OIDC code. --- .../authenticator/token/oidc/oidc_test.go | 210 +++--------------- .../token/oidc/testing/provider.go | 172 ++++++++++++++ plugin/pkg/client/auth/oidc/oidc_test.go | 75 +++++++ 3 files changed, 283 insertions(+), 174 deletions(-) create mode 100644 plugin/pkg/auth/authenticator/token/oidc/testing/provider.go create mode 100644 plugin/pkg/client/auth/oidc/oidc_test.go diff --git a/plugin/pkg/auth/authenticator/token/oidc/oidc_test.go b/plugin/pkg/auth/authenticator/token/oidc/oidc_test.go index 3836adeb84d..5450052a800 100644 --- a/plugin/pkg/auth/authenticator/token/oidc/oidc_test.go +++ b/plugin/pkg/auth/authenticator/token/oidc/oidc_test.go @@ -17,60 +17,23 @@ limitations under the License. package oidc import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/json" - "encoding/pem" "fmt" - "io/ioutil" - "math/big" - "net" - "net/http" "net/http/httptest" "net/url" "os" "path" - "path/filepath" "reflect" "strings" "testing" "time" "github.com/coreos/go-oidc/jose" - "github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/oidc" + "k8s.io/kubernetes/pkg/auth/user" + oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing" ) -type oidcProvider struct { - mux *http.ServeMux - pcfg oidc.ProviderConfig - privKey *key.PrivateKey -} - -func newOIDCProvider(t *testing.T) *oidcProvider { - privKey, err := key.GeneratePrivateKey() - if err != nil { - t.Fatalf("Cannot create OIDC Provider: %v", err) - return nil - } - - op := &oidcProvider{ - mux: http.NewServeMux(), - privKey: privKey, - } - - op.mux.HandleFunc("/.well-known/openid-configuration", op.handleConfig) - op.mux.HandleFunc("/keys", op.handleKeys) - - return op - -} - func mustParseURL(t *testing.T, s string) *url.URL { u, err := url.Parse(s) if err != nil { @@ -79,37 +42,8 @@ func mustParseURL(t *testing.T, s string) *url.URL { return u } -func (op *oidcProvider) handleConfig(w http.ResponseWriter, req *http.Request) { - b, err := json.Marshal(&op.pcfg) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - w.Write(b) -} - -func (op *oidcProvider) handleKeys(w http.ResponseWriter, req *http.Request) { - keys := struct { - Keys []jose.JWK `json:"keys"` - }{ - Keys: []jose.JWK{op.privKey.JWK()}, - } - - b, err := json.Marshal(keys) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(time.Hour.Seconds()))) - w.Header().Set("Expires", time.Now().Add(time.Hour).Format(time.RFC1123)) - w.Header().Set("Content-Type", "application/json") - w.Write(b) -} - -func (op *oidcProvider) generateToken(t *testing.T, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string, iat, exp time.Time) string { - signer := op.privKey.Signer() +func generateToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string, iat, exp time.Time) string { + signer := op.PrivKey.Signer() claims := oidc.NewClaims(iss, sub, aud, iat, exp) claims.Add(usernameClaim, value) if groups != nil && groupsClaim != "" { @@ -124,79 +58,16 @@ func (op *oidcProvider) generateToken(t *testing.T, iss, sub, aud string, userna return jwt.Encode() } -func (op *oidcProvider) generateGoodToken(t *testing.T, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string { - return op.generateToken(t, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now(), time.Now().Add(time.Hour)) +func generateGoodToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string { + return generateToken(t, op, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now(), time.Now().Add(time.Hour)) } -func (op *oidcProvider) generateMalformedToken(t *testing.T, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string { - return op.generateToken(t, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now(), time.Now().Add(time.Hour)) + "randombits" +func generateMalformedToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string { + return generateToken(t, op, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now(), time.Now().Add(time.Hour)) + "randombits" } -func (op *oidcProvider) generateExpiredToken(t *testing.T, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string { - return op.generateToken(t, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour)) -} - -// generateSelfSignedCert generates a self-signed cert/key pairs and writes to the certPath/keyPath. -// This method is mostly identical to crypto.GenerateSelfSignedCert except for the 'IsCA' and 'KeyUsage' -// in the certificate template. (Maybe we can merge these two methods). -func generateSelfSignedCert(t *testing.T, host, certPath, keyPath string) { - priv, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatal(err) - } - - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()), - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour * 24 * 365), - - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - IsCA: true, - } - - if ip := net.ParseIP(host); ip != nil { - template.IPAddresses = append(template.IPAddresses, ip) - } else { - template.DNSNames = append(template.DNSNames, host) - } - - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - t.Fatal(err) - } - - // Generate cert - certBuffer := bytes.Buffer{} - if err := pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { - t.Fatal(err) - } - - // Generate key - keyBuffer := bytes.Buffer{} - if err := pem.Encode(&keyBuffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil { - t.Fatal(err) - } - - // Write cert - if err := os.MkdirAll(filepath.Dir(certPath), os.FileMode(0755)); err != nil { - t.Fatal(err) - } - if err := ioutil.WriteFile(certPath, certBuffer.Bytes(), os.FileMode(0644)); err != nil { - t.Fatal(err) - } - - // Write key - if err := os.MkdirAll(filepath.Dir(keyPath), os.FileMode(0755)); err != nil { - t.Fatal(err) - } - if err := ioutil.WriteFile(keyPath, keyBuffer.Bytes(), os.FileMode(0600)); err != nil { - t.Fatal(err) - } +func generateExpiredToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string) string { + return generateToken(t, op, iss, sub, aud, usernameClaim, value, groupsClaim, groups, time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour)) } func TestOIDCDiscoveryTimeout(t *testing.T) { @@ -217,19 +88,16 @@ func TestOIDCDiscoveryNoKeyEndpoint(t *testing.T) { defer os.Remove(cert) defer os.Remove(key) - generateSelfSignedCert(t, "127.0.0.1", cert, key) + oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert, key) - op := newOIDCProvider(t) - srv := httptest.NewUnstartedServer(op.mux) - srv.TLS = &tls.Config{Certificates: make([]tls.Certificate, 1)} - srv.TLS.Certificates[0], err = tls.LoadX509KeyPair(cert, key) + op := oidctesting.NewOIDCProvider(t) + srv, err := op.ServeTLSWithKeyPair(cert, key) if err != nil { - t.Fatalf("Cannot load cert/key pair: %v", err) + t.Fatalf("Cannot start server %v", err) } - srv.StartTLS() defer srv.Close() - op.pcfg = oidc.ProviderConfig{ + op.PCFG = oidc.ProviderConfig{ Issuer: mustParseURL(t, srv.URL), // An invalid ProviderConfig. Keys endpoint is required. } @@ -241,11 +109,11 @@ func TestOIDCDiscoveryNoKeyEndpoint(t *testing.T) { func TestOIDCDiscoverySecureConnection(t *testing.T) { // Verify that plain HTTP issuer URL is forbidden. - op := newOIDCProvider(t) - srv := httptest.NewServer(op.mux) + op := oidctesting.NewOIDCProvider(t) + srv := httptest.NewServer(op.Mux) defer srv.Close() - op.pcfg = oidc.ProviderConfig{ + op.PCFG = oidc.ProviderConfig{ Issuer: mustParseURL(t, srv.URL), KeysEndpoint: mustParseURL(t, srv.URL+"/keys"), } @@ -268,20 +136,17 @@ func TestOIDCDiscoverySecureConnection(t *testing.T) { defer os.Remove(cert2) defer os.Remove(key2) - generateSelfSignedCert(t, "127.0.0.1", cert1, key1) - generateSelfSignedCert(t, "127.0.0.1", cert2, key2) + oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert1, key1) + oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert2, key2) // Create a TLS server using cert/key pair 1. - tlsSrv := httptest.NewUnstartedServer(op.mux) - tlsSrv.TLS = &tls.Config{Certificates: make([]tls.Certificate, 1)} - tlsSrv.TLS.Certificates[0], err = tls.LoadX509KeyPair(cert1, key1) + tlsSrv, err := op.ServeTLSWithKeyPair(cert1, key1) if err != nil { - t.Fatalf("Cannot load cert/key pair: %v", err) + t.Fatalf("Cannot start server: %v", err) } - tlsSrv.StartTLS() defer tlsSrv.Close() - op.pcfg = oidc.ProviderConfig{ + op.PCFG = oidc.ProviderConfig{ Issuer: mustParseURL(t, tlsSrv.URL), KeysEndpoint: mustParseURL(t, tlsSrv.URL+"/keys"), } @@ -303,21 +168,18 @@ func TestOIDCAuthentication(t *testing.T) { defer os.Remove(cert) defer os.Remove(key) - generateSelfSignedCert(t, "127.0.0.1", cert, key) + oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert, key) // Create a TLS server and a client. - op := newOIDCProvider(t) - srv := httptest.NewUnstartedServer(op.mux) - srv.TLS = &tls.Config{Certificates: make([]tls.Certificate, 1)} - srv.TLS.Certificates[0], err = tls.LoadX509KeyPair(cert, key) + op := oidctesting.NewOIDCProvider(t) + srv, err := op.ServeTLSWithKeyPair(cert, key) if err != nil { - t.Fatalf("Cannot load cert/key pair: %v", err) + t.Fatalf("Cannot start server: %v", err) } - srv.StartTLS() defer srv.Close() // A provider config with all required fields. - op.pcfg = oidc.ProviderConfig{ + op.PCFG = oidc.ProviderConfig{ Issuer: mustParseURL(t, srv.URL), AuthEndpoint: mustParseURL(t, srv.URL+"/auth"), TokenEndpoint: mustParseURL(t, srv.URL+"/token"), @@ -338,7 +200,7 @@ func TestOIDCAuthentication(t *testing.T) { { "sub", "", - op.generateGoodToken(t, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil), + generateGoodToken(t, op, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil), &user.DefaultInfo{Name: fmt.Sprintf("%s#%s", srv.URL, "user-foo")}, true, "", @@ -347,7 +209,7 @@ func TestOIDCAuthentication(t *testing.T) { // Use user defined claim (email here). "email", "", - op.generateGoodToken(t, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "", nil), + generateGoodToken(t, op, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "", nil), &user.DefaultInfo{Name: "foo@example.com"}, true, "", @@ -356,7 +218,7 @@ func TestOIDCAuthentication(t *testing.T) { // Use user defined claim (email here). "email", "", - op.generateGoodToken(t, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "groups", []string{"group1", "group2"}), + generateGoodToken(t, op, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "groups", []string{"group1", "group2"}), &user.DefaultInfo{Name: "foo@example.com"}, true, "", @@ -365,7 +227,7 @@ func TestOIDCAuthentication(t *testing.T) { // Use user defined claim (email here). "email", "groups", - op.generateGoodToken(t, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "groups", []string{"group1", "group2"}), + generateGoodToken(t, op, srv.URL, "client-foo", "client-foo", "email", "foo@example.com", "groups", []string{"group1", "group2"}), &user.DefaultInfo{Name: "foo@example.com", Groups: []string{"group1", "group2"}}, true, "", @@ -373,7 +235,7 @@ func TestOIDCAuthentication(t *testing.T) { { "sub", "", - op.generateMalformedToken(t, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil), + generateMalformedToken(t, op, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil), nil, false, "oidc: unable to verify JWT signature: no matching keys", @@ -382,7 +244,7 @@ func TestOIDCAuthentication(t *testing.T) { // Invalid 'aud'. "sub", "", - op.generateGoodToken(t, srv.URL, "client-foo", "client-bar", "sub", "user-foo", "", nil), + generateGoodToken(t, op, srv.URL, "client-foo", "client-bar", "sub", "user-foo", "", nil), nil, false, "oidc: JWT claims invalid: invalid claims, 'aud' claim and 'client_id' do not match", @@ -391,7 +253,7 @@ func TestOIDCAuthentication(t *testing.T) { // Invalid issuer. "sub", "", - op.generateGoodToken(t, "http://foo-bar.com", "client-foo", "client-foo", "sub", "user-foo", "", nil), + generateGoodToken(t, op, "http://foo-bar.com", "client-foo", "client-foo", "sub", "user-foo", "", nil), nil, false, "oidc: JWT claims invalid: invalid claim value: 'iss'.", @@ -399,7 +261,7 @@ func TestOIDCAuthentication(t *testing.T) { { "sub", "", - op.generateExpiredToken(t, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil), + generateExpiredToken(t, op, srv.URL, "client-foo", "client-foo", "sub", "user-foo", "", nil), nil, false, "oidc: JWT claims invalid: token is expired", diff --git a/plugin/pkg/auth/authenticator/token/oidc/testing/provider.go b/plugin/pkg/auth/authenticator/token/oidc/testing/provider.go new file mode 100644 index 00000000000..5396b1dde3a --- /dev/null +++ b/plugin/pkg/auth/authenticator/token/oidc/testing/provider.go @@ -0,0 +1,172 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testing + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "fmt" + "io/ioutil" + "math/big" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/coreos/go-oidc/jose" + "github.com/coreos/go-oidc/key" + "github.com/coreos/go-oidc/oidc" +) + +// NewOIDCProvider provides a bare minimum OIDC IdP Server useful for testing. +func NewOIDCProvider(t *testing.T) *OIDCProvider { + privKey, err := key.GeneratePrivateKey() + if err != nil { + t.Fatalf("Cannot create OIDC Provider: %v", err) + return nil + } + + op := &OIDCProvider{ + Mux: http.NewServeMux(), + PrivKey: privKey, + } + + op.Mux.HandleFunc("/.well-known/openid-configuration", op.handleConfig) + op.Mux.HandleFunc("/keys", op.handleKeys) + + return op +} + +type OIDCProvider struct { + Mux *http.ServeMux + PCFG oidc.ProviderConfig + PrivKey *key.PrivateKey +} + +func (op *OIDCProvider) ServeTLSWithKeyPair(cert, key string) (*httptest.Server, error) { + srv := httptest.NewUnstartedServer(op.Mux) + + srv.TLS = &tls.Config{Certificates: make([]tls.Certificate, 1)} + var err error + srv.TLS.Certificates[0], err = tls.LoadX509KeyPair(cert, key) + if err != nil { + return nil, fmt.Errorf("Cannot load cert/key pair: %v", err) + } + srv.StartTLS() + return srv, nil +} + +func (op *OIDCProvider) handleConfig(w http.ResponseWriter, req *http.Request) { + b, err := json.Marshal(&op.PCFG) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(b) +} + +func (op *OIDCProvider) handleKeys(w http.ResponseWriter, req *http.Request) { + keys := struct { + Keys []jose.JWK `json:"keys"` + }{ + Keys: []jose.JWK{op.PrivKey.JWK()}, + } + + b, err := json.Marshal(keys) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(time.Hour.Seconds()))) + w.Header().Set("Expires", time.Now().Add(time.Hour).Format(time.RFC1123)) + w.Header().Set("Content-Type", "application/json") + w.Write(b) +} + +// generateSelfSignedCert generates a self-signed cert/key pairs and writes to the certPath/keyPath. +// This method is mostly identical to crypto.GenerateSelfSignedCert except for the 'IsCA' and 'KeyUsage' +// in the certificate template. (Maybe we can merge these two methods). +func GenerateSelfSignedCert(t *testing.T, host, certPath, keyPath string) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()), + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 365), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + + if ip := net.ParseIP(host); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, host) + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + t.Fatal(err) + } + + // Generate cert + certBuffer := bytes.Buffer{} + if err := pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + t.Fatal(err) + } + + // Generate key + keyBuffer := bytes.Buffer{} + if err := pem.Encode(&keyBuffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil { + t.Fatal(err) + } + + // Write cert + if err := os.MkdirAll(filepath.Dir(certPath), os.FileMode(0755)); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(certPath, certBuffer.Bytes(), os.FileMode(0644)); err != nil { + t.Fatal(err) + } + + // Write key + if err := os.MkdirAll(filepath.Dir(keyPath), os.FileMode(0755)); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(keyPath, keyBuffer.Bytes(), os.FileMode(0600)); err != nil { + t.Fatal(err) + } +} diff --git a/plugin/pkg/client/auth/oidc/oidc_test.go b/plugin/pkg/client/auth/oidc/oidc_test.go new file mode 100644 index 00000000000..cba7395ebc0 --- /dev/null +++ b/plugin/pkg/client/auth/oidc/oidc_test.go @@ -0,0 +1,75 @@ +package oidc + +import ( + "testing" + + "k8s.io/kubernetes/pkg/util/diff" + + "github.com/coreos/go-oidc/jose" +) + +func TestNewOIDCAuthProvider(t *testing.T) { + tests := []struct { + cfg map[string]string + + wantErr bool + wantInitialIDToken jose.JWT + }{ + { + cfg: map[string]string{ + cfgIssuerUrl: "auth.example.com", + }, + }, + } + + for i, tt := range tests { + ap, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, nil) + if tt.wantErr { + if err == nil { + t.Errorf("case %d: want non-nil err", i) + continue + } + } + + if err != nil { + t.Errorf("case %d: unexpected error on newOIDCAuthProvider: %v", i, err) + continue + } + + oidcAP, ok := ap.(*oidcAuthProvider) + if !ok { + t.Errorf("case %d: expected ap to be an oidcAuthProvider", i) + continue + } + + if diff := compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken); diff != "" { + t.Errorf("case %d: compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken)=%v", i, diff) + } + } +} + +func compareJWTs(a, b jose.JWT) string { + if a.Encode() == b.Encode() { + return "" + } + + var aClaims, bClaims jose.Claims + for _, j := range []struct { + claims *jose.Claims + jwt jose.JWT + }{ + {&aClaims, a}, + {&bClaims, b}, + } { + var err error + *j.claims, err = j.jwt.Claims() + if err != nil { + *j.claims = jose.Claims(map[string]interface{}{ + "msg": "bad claims", + "err": err, + }) + } + } + + return diff.ObjectDiff(a, b) +} From e85940ed179f8be36f2333d13f3f1297bb7d9178 Mon Sep 17 00:00:00 2001 From: Bobby Rullo Date: Wed, 11 May 2016 16:49:22 -0700 Subject: [PATCH 4/6] add tests for newOIDCAuthProvider --- .../authenticator/token/oidc/oidc_test.go | 29 +---- .../token/oidc/testing/provider.go | 22 ++++ plugin/pkg/client/auth/oidc/oidc.go | 42 ++++--- plugin/pkg/client/auth/oidc/oidc_test.go | 116 +++++++++++++++++- 4 files changed, 166 insertions(+), 43 deletions(-) diff --git a/plugin/pkg/auth/authenticator/token/oidc/oidc_test.go b/plugin/pkg/auth/authenticator/token/oidc/oidc_test.go index 5450052a800..f0dab980aa8 100644 --- a/plugin/pkg/auth/authenticator/token/oidc/oidc_test.go +++ b/plugin/pkg/auth/authenticator/token/oidc/oidc_test.go @@ -19,7 +19,6 @@ package oidc import ( "fmt" "net/http/httptest" - "net/url" "os" "path" "reflect" @@ -34,14 +33,6 @@ import ( oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing" ) -func mustParseURL(t *testing.T, s string) *url.URL { - u, err := url.Parse(s) - if err != nil { - t.Fatalf("Failed to parse url: %v", err) - } - return u -} - func generateToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string, iat, exp time.Time) string { signer := op.PrivKey.Signer() claims := oidc.NewClaims(iss, sub, aud, iat, exp) @@ -98,7 +89,7 @@ func TestOIDCDiscoveryNoKeyEndpoint(t *testing.T) { defer srv.Close() op.PCFG = oidc.ProviderConfig{ - Issuer: mustParseURL(t, srv.URL), // An invalid ProviderConfig. Keys endpoint is required. + Issuer: oidctesting.MustParseURL(srv.URL), // An invalid ProviderConfig. Keys endpoint is required. } _, err = New(OIDCOptions{srv.URL, "client-foo", cert, "sub", "", 0, 0}) @@ -114,8 +105,8 @@ func TestOIDCDiscoverySecureConnection(t *testing.T) { defer srv.Close() op.PCFG = oidc.ProviderConfig{ - Issuer: mustParseURL(t, srv.URL), - KeysEndpoint: mustParseURL(t, srv.URL+"/keys"), + Issuer: oidctesting.MustParseURL(srv.URL), + KeysEndpoint: oidctesting.MustParseURL(srv.URL + "/keys"), } expectErr := fmt.Errorf("'oidc-issuer-url' (%q) has invalid scheme (%q), require 'https'", srv.URL, "http") @@ -147,8 +138,8 @@ func TestOIDCDiscoverySecureConnection(t *testing.T) { defer tlsSrv.Close() op.PCFG = oidc.ProviderConfig{ - Issuer: mustParseURL(t, tlsSrv.URL), - KeysEndpoint: mustParseURL(t, tlsSrv.URL+"/keys"), + Issuer: oidctesting.MustParseURL(tlsSrv.URL), + KeysEndpoint: oidctesting.MustParseURL(tlsSrv.URL + "/keys"), } // Create a client using cert2, should fail. @@ -179,15 +170,7 @@ func TestOIDCAuthentication(t *testing.T) { defer srv.Close() // A provider config with all required fields. - op.PCFG = oidc.ProviderConfig{ - Issuer: mustParseURL(t, srv.URL), - AuthEndpoint: mustParseURL(t, srv.URL+"/auth"), - TokenEndpoint: mustParseURL(t, srv.URL+"/token"), - KeysEndpoint: mustParseURL(t, srv.URL+"/keys"), - ResponseTypesSupported: []string{"code"}, - SubjectTypesSupported: []string{"public"}, - IDTokenSigningAlgValues: []string{"RS256"}, - } + op.AddMinimalProviderConfig(srv) tests := []struct { userClaim string diff --git a/plugin/pkg/auth/authenticator/token/oidc/testing/provider.go b/plugin/pkg/auth/authenticator/token/oidc/testing/provider.go index 5396b1dde3a..ee2735d539a 100644 --- a/plugin/pkg/auth/authenticator/token/oidc/testing/provider.go +++ b/plugin/pkg/auth/authenticator/token/oidc/testing/provider.go @@ -31,6 +31,7 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "testing" @@ -79,6 +80,19 @@ func (op *OIDCProvider) ServeTLSWithKeyPair(cert, key string) (*httptest.Server, return srv, nil } +func (op *OIDCProvider) AddMinimalProviderConfig(srv *httptest.Server) { + op.PCFG = oidc.ProviderConfig{ + Issuer: MustParseURL(srv.URL), + AuthEndpoint: MustParseURL(srv.URL + "/auth"), + TokenEndpoint: MustParseURL(srv.URL + "/token"), + KeysEndpoint: MustParseURL(srv.URL + "/keys"), + ResponseTypesSupported: []string{"code"}, + SubjectTypesSupported: []string{"public"}, + IDTokenSigningAlgValues: []string{"RS256"}, + } + +} + func (op *OIDCProvider) handleConfig(w http.ResponseWriter, req *http.Request) { b, err := json.Marshal(&op.PCFG) if err != nil { @@ -108,6 +122,14 @@ func (op *OIDCProvider) handleKeys(w http.ResponseWriter, req *http.Request) { w.Write(b) } +func MustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(fmt.Errorf("Failed to parse url: %v", err)) + } + return u +} + // generateSelfSignedCert generates a self-signed cert/key pairs and writes to the certPath/keyPath. // This method is mostly identical to crypto.GenerateSelfSignedCert except for the 'IsCA' and 'KeyUsage' // in the certificate template. (Maybe we can merge these two methods). diff --git a/plugin/pkg/client/auth/oidc/oidc.go b/plugin/pkg/client/auth/oidc/oidc.go index c752ba2b4a8..bf37d569b8e 100644 --- a/plugin/pkg/client/auth/oidc/oidc.go +++ b/plugin/pkg/client/auth/oidc/oidc.go @@ -34,12 +34,12 @@ import ( const ( cfgIssuerUrl = "idp-issuer-url" - cfgClientId = "client-id" + cfgClientID = "client-id" cfgClientSecret = "client-secret" cfgCertificateAuthority = "idp-certificate-authority" cfgCertificateAuthorityData = "idp-certificate-authority-data" cfgExtraScopes = "extra-scopes" - cfgIdToken = "id-token" + cfgIDToken = "id-token" cfgRefreshToken = "refresh-token" ) @@ -50,6 +50,21 @@ func init() { } func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) { + issuer := cfg[cfgIssuerUrl] + if issuer == "" { + return nil, fmt.Errorf("Must provide %s", cfgIssuerUrl) + } + + clientID := cfg[cfgClientID] + if clientID == "" { + return nil, fmt.Errorf("Must provide %s", cfgClientID) + } + + clientSecret := cfg[cfgClientSecret] + if clientSecret == "" { + return nil, fmt.Errorf("Must provide %s", cfgClientSecret) + } + var certAuthData []byte var err error if cfg[cfgCertificateAuthorityData] != "" { @@ -72,11 +87,6 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A } hc := &http.Client{Transport: trans} - issuer, ok := cfg[cfgIssuerUrl] - if !ok || issuer == "" { - return nil, errors.New("Must provide idp-issuer-url") - } - providerCfg, err := oidc.FetchProviderConfig(hc, strings.TrimSuffix(issuer, "/")) if err != nil { return nil, fmt.Errorf("error fetching provider config: %v", err) @@ -86,8 +96,8 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A oidcCfg := oidc.ClientConfig{ HTTPClient: hc, Credentials: oidc.ClientCredentials{ - ID: cfg[cfgClientId], - Secret: cfg[cfgClientSecret], + ID: clientID, + Secret: clientSecret, }, ProviderConfig: providerCfg, Scope: scopes, @@ -99,15 +109,15 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A } var initialIDToken jose.JWT - if cfg[cfgIdToken] != "" { - initialIDToken, err = jose.ParseJWT(cfg[cfgIdToken]) + if cfg[cfgIDToken] != "" { + initialIDToken, err = jose.ParseJWT(cfg[cfgIDToken]) if err != nil { return nil, err } } return &oidcAuthProvider{ - intialIDToken: initialIDToken, + initialIDToken: initialIDToken, refresher: &idTokenRefresher{ client: client, cfg: cfg, @@ -117,8 +127,8 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A } type oidcAuthProvider struct { - refresher *idTokenRefresher - intialIDToken jose.JWT + refresher *idTokenRefresher + initialIDToken jose.JWT } func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper { @@ -126,7 +136,7 @@ func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper TokenRefresher: g.refresher, RoundTripper: rt, } - at.SetJWT(g.intialIDToken) + at.SetJWT(g.initialIDToken) return at } @@ -185,7 +195,7 @@ func (r *idTokenRefresher) Refresh() (jose.JWT, error) { if tokens.RefreshToken != "" && tokens.RefreshToken != rt { r.cfg[cfgRefreshToken] = tokens.RefreshToken } - r.cfg[cfgIdToken] = jwt.Encode() + r.cfg[cfgIDToken] = jwt.Encode() err = r.persister.Persist(r.cfg) if err != nil { diff --git a/plugin/pkg/client/auth/oidc/oidc_test.go b/plugin/pkg/client/auth/oidc/oidc_test.go index cba7395ebc0..767781e63b2 100644 --- a/plugin/pkg/client/auth/oidc/oidc_test.go +++ b/plugin/pkg/client/auth/oidc/oidc_test.go @@ -1,14 +1,62 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package oidc import ( + "encoding/base64" + "io/ioutil" + "os" + "path" "testing" - "k8s.io/kubernetes/pkg/util/diff" - "github.com/coreos/go-oidc/jose" + + "k8s.io/kubernetes/pkg/util/diff" + oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing" ) func TestNewOIDCAuthProvider(t *testing.T) { + cert := path.Join(os.TempDir(), "oidc-cert") + key := path.Join(os.TempDir(), "oidc-key") + + defer os.Remove(cert) + defer os.Remove(key) + + oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert, key) + op := oidctesting.NewOIDCProvider(t) + srv, err := op.ServeTLSWithKeyPair(cert, key) + op.AddMinimalProviderConfig(srv) + if err != nil { + t.Fatalf("Cannot start server %v", err) + } + defer srv.Close() + + certData, err := ioutil.ReadFile(cert) + if err != nil { + t.Fatalf("Could not read cert bytes %v", err) + } + + jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ + "test": "jwt", + }), op.PrivKey.Signer()) + if err != nil { + t.Fatalf("Could not create signed JWT %v", err) + } + tests := []struct { cfg map[string]string @@ -16,10 +64,70 @@ func TestNewOIDCAuthProvider(t *testing.T) { wantInitialIDToken jose.JWT }{ { + // A Valid configuration cfg: map[string]string{ - cfgIssuerUrl: "auth.example.com", + cfgIssuerUrl: srv.URL, + cfgCertificateAuthority: cert, + cfgClientID: "client-id", + cfgClientSecret: "client-secret", }, }, + { + // A Valid configuration with an Initial JWT + cfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgCertificateAuthority: cert, + cfgClientID: "client-id", + cfgClientSecret: "client-secret", + cfgIDToken: jwt.Encode(), + }, + wantInitialIDToken: *jwt, + }, + { + // Valid config, but using cfgCertificateAuthorityData + cfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData), + cfgClientID: "client-id", + cfgClientSecret: "client-secret", + }, + }, + { + // Missing client id + cfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgCertificateAuthority: cert, + cfgClientSecret: "client-secret", + }, + wantErr: true, + }, + { + // Missing client secret + cfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgCertificateAuthority: cert, + cfgClientID: "client-id", + }, + wantErr: true, + }, + { + // Missing issuer url. + cfg: map[string]string{ + cfgCertificateAuthority: cert, + cfgClientID: "client-id", + cfgClientSecret: "secret", + }, + wantErr: true, + }, + { + // No TLS config + cfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgClientID: "client-id", + cfgClientSecret: "secret", + }, + wantErr: true, + }, } for i, tt := range tests { @@ -27,8 +135,8 @@ func TestNewOIDCAuthProvider(t *testing.T) { if tt.wantErr { if err == nil { t.Errorf("case %d: want non-nil err", i) - continue } + continue } if err != nil { From 94ffa344a84f720fa48190caeab103d9d41c5015 Mon Sep 17 00:00:00 2001 From: Bobby Rullo Date: Thu, 12 May 2016 17:14:05 -0700 Subject: [PATCH 5/6] OIDC authprovider more testable, and add backoff * Use an interface for OIDC Client, so that we're testing the behavior of the client, not the go-oidc package itself * add backoff and retry when server rejects token --- plugin/pkg/client/auth/oidc/oidc.go | 90 ++++++++++++++++++++++++----- 1 file changed, 77 insertions(+), 13 deletions(-) diff --git a/plugin/pkg/client/auth/oidc/oidc.go b/plugin/pkg/client/auth/oidc/oidc.go index bf37d569b8e..3ad279c106e 100644 --- a/plugin/pkg/client/auth/oidc/oidc.go +++ b/plugin/pkg/client/auth/oidc/oidc.go @@ -30,6 +30,7 @@ import ( "github.com/golang/glog" "k8s.io/kubernetes/pkg/client/restclient" + "k8s.io/kubernetes/pkg/util/wait" ) const ( @@ -43,6 +44,15 @@ const ( cfgRefreshToken = "refresh-token" ) +var ( + backoff = wait.Backoff{ + Duration: 1 * time.Second, + Factor: 2, + Jitter: .1, + Steps: 5, + } +) + func init() { if err := restclient.RegisterAuthProviderPlugin("oidc", newOIDCAuthProvider); err != nil { glog.Fatalf("Failed to register oidc auth plugin: %v", err) @@ -100,7 +110,7 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A Secret: clientSecret, }, ProviderConfig: providerCfg, - Scope: scopes, + Scope: append(scopes, oidc.DefaultScope...), } client, err := oidc.NewClient(oidcCfg) @@ -108,6 +118,8 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A return nil, fmt.Errorf("error creating OIDC Client: %v", err) } + oClient := &oidcClient{client} + var initialIDToken jose.JWT if cfg[cfgIDToken] != "" { initialIDToken, err = jose.ParseJWT(cfg[cfgIDToken]) @@ -119,7 +131,7 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A return &oidcAuthProvider{ initialIDToken: initialIDToken, refresher: &idTokenRefresher{ - client: client, + client: oClient, cfg: cfg, persister: persister, }, @@ -137,16 +149,57 @@ func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper RoundTripper: rt, } at.SetJWT(g.initialIDToken) - return at + return &roundTripper{ + wrapped: at, + refresher: g.refresher, + } } func (g *oidcAuthProvider) Login() error { return errors.New("not yet implemented") } +type OIDCClient interface { + refreshToken(rt string) (oauth2.TokenResponse, error) + verifyJWT(jwt jose.JWT) error +} + +type roundTripper struct { + refresher *idTokenRefresher + wrapped *oidc.AuthenticatedTransport +} + +func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + var res *http.Response + var err error + firstTime := true + wait.ExponentialBackoff(backoff, func() (bool, error) { + if !firstTime { + var jwt jose.JWT + jwt, err = r.refresher.Refresh() + if err != nil { + return true, nil + } + r.wrapped.SetJWT(jwt) + } else { + firstTime = false + } + + res, err = r.wrapped.RoundTrip(req) + if err != nil { + return true, nil + } + if res.StatusCode == http.StatusUnauthorized { + return false, nil + } + return true, nil + }) + return res, err +} + type idTokenRefresher struct { cfg map[string]string - client *oidc.Client + client OIDCClient persister restclient.AuthProviderConfigPersister intialIDToken jose.JWT } @@ -177,16 +230,10 @@ func (r *idTokenRefresher) Refresh() (jose.JWT, error) { return jose.JWT{}, errors.New("No valid id-token, and cannot refresh without refresh-token") } - oac, err := r.client.OAuthClient() + tokens, err := r.client.refreshToken(rt) if err != nil { - return jose.JWT{}, err + return jose.JWT{}, fmt.Errorf("could not refresh token: %v", err) } - - tokens, err := oac.RequestToken(oauth2.GrantTypeRefreshToken, rt) - if err != nil { - return jose.JWT{}, err - } - jwt, err := jose.ParseJWT(tokens.IDToken) if err != nil { return jose.JWT{}, err @@ -202,5 +249,22 @@ func (r *idTokenRefresher) Refresh() (jose.JWT, error) { return jose.JWT{}, fmt.Errorf("could not perist new tokens: %v", err) } - return jwt, r.client.VerifyJWT(jwt) + return jwt, r.client.verifyJWT(jwt) +} + +type oidcClient struct { + client *oidc.Client +} + +func (o *oidcClient) refreshToken(rt string) (oauth2.TokenResponse, error) { + oac, err := o.client.OAuthClient() + if err != nil { + return oauth2.TokenResponse{}, err + } + + return oac.RequestToken(oauth2.GrantTypeRefreshToken, rt) +} + +func (o *oidcClient) verifyJWT(jwt jose.JWT) error { + return o.client.VerifyJWT(jwt) } From f575f89cd7eb2ef13dce8c32a8ade79752c0be18 Mon Sep 17 00:00:00 2001 From: Bobby Rullo Date: Thu, 12 May 2016 17:17:33 -0700 Subject: [PATCH 6/6] add tests for the OIDC WrapTransport tests that tokens gets refreshed, passed along as bearers, etc. --- plugin/pkg/client/auth/oidc/oidc_test.go | 461 ++++++++++++++++++++++- 1 file changed, 460 insertions(+), 1 deletion(-) diff --git a/plugin/pkg/client/auth/oidc/oidc_test.go b/plugin/pkg/client/auth/oidc/oidc_test.go index 767781e63b2..30d8e4ab208 100644 --- a/plugin/pkg/client/auth/oidc/oidc_test.go +++ b/plugin/pkg/client/auth/oidc/oidc_test.go @@ -18,14 +18,23 @@ package oidc import ( "encoding/base64" + "errors" + "fmt" "io/ioutil" + "net/http" "os" "path" + "reflect" + "strings" "testing" + "time" "github.com/coreos/go-oidc/jose" + "github.com/coreos/go-oidc/key" + "github.com/coreos/go-oidc/oauth2" "k8s.io/kubernetes/pkg/util/diff" + "k8s.io/kubernetes/pkg/util/wait" oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing" ) @@ -156,6 +165,456 @@ func TestNewOIDCAuthProvider(t *testing.T) { } } +func TestWrapTranport(t *testing.T) { + oldBackoff := backoff + defer func() { + backoff = oldBackoff + }() + backoff = wait.Backoff{ + Duration: 1 * time.Nanosecond, + Steps: 3, + } + + privKey, err := key.GeneratePrivateKey() + if err != nil { + t.Fatalf("can't generate private key: %v", err) + } + + makeToken := func(s string, exp time.Time, count int) *jose.JWT { + jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ + "test": s, + "exp": exp.UTC().Unix(), + "count": count, + }), privKey.Signer()) + if err != nil { + t.Fatalf("Could not create signed JWT %v", err) + } + return jwt + } + + goodToken := makeToken("good", time.Now().Add(time.Hour), 0) + goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1) + expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0) + + str := func(s string) *string { + return &s + } + tests := []struct { + cfgIDToken *jose.JWT + cfgRefreshToken *string + + expectRequests []testRoundTrip + + expectRefreshes []testRefresh + + expectPersists []testPersist + + wantStatus int + wantErr bool + }{ + { + // Initial JWT is set, it is good, it is set as bearer. + cfgIDToken: goodToken, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: 200, + }, + }, + + wantStatus: 200, + }, + { + // Initial JWT is set, but it's expired, so it gets refreshed. + cfgIDToken: expiredToken, + cfgRefreshToken: str("rt1"), + + expectRefreshes: []testRefresh{ + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken.Encode(), + }, + }, + }, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: 200, + }, + }, + + expectPersists: []testPersist{ + { + cfg: map[string]string{ + cfgIDToken: goodToken.Encode(), + cfgRefreshToken: "rt1", + }, + }, + }, + + wantStatus: 200, + }, + { + // Initial JWT is set, but it's expired, so it gets refreshed - this + // time the refresh token itself is also refreshed + cfgIDToken: expiredToken, + cfgRefreshToken: str("rt1"), + + expectRefreshes: []testRefresh{ + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken.Encode(), + RefreshToken: "rt2", + }, + }, + }, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: 200, + }, + }, + + expectPersists: []testPersist{ + { + cfg: map[string]string{ + cfgIDToken: goodToken.Encode(), + cfgRefreshToken: "rt2", + }, + }, + }, + + wantStatus: 200, + }, + { + // Initial JWT is not set, so it gets refreshed. + cfgRefreshToken: str("rt1"), + + expectRefreshes: []testRefresh{ + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken.Encode(), + }, + }, + }, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: 200, + }, + }, + + expectPersists: []testPersist{ + { + cfg: map[string]string{ + cfgIDToken: goodToken.Encode(), + cfgRefreshToken: "rt1", + }, + }, + }, + + wantStatus: 200, + }, + { + // Expired token, but no refresh token. + cfgIDToken: expiredToken, + + wantErr: true, + }, + { + // Initial JWT is not set, so it gets refreshed, but the server + // rejects it when it is used, so it refreshes again, which + // succeeds. + cfgRefreshToken: str("rt1"), + + expectRefreshes: []testRefresh{ + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken.Encode(), + }, + }, + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken2.Encode(), + }, + }, + }, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: http.StatusUnauthorized, + }, + { + expectBearerToken: goodToken2.Encode(), + returnHTTPStatus: http.StatusOK, + }, + }, + + expectPersists: []testPersist{ + { + cfg: map[string]string{ + cfgIDToken: goodToken.Encode(), + cfgRefreshToken: "rt1", + }, + }, + { + cfg: map[string]string{ + cfgIDToken: goodToken2.Encode(), + cfgRefreshToken: "rt1", + }, + }, + }, + + wantStatus: 200, + }, + { + // Initial JWT is but the server rejects it when it is used, so it + // refreshes again, which succeeds. + cfgRefreshToken: str("rt1"), + cfgIDToken: goodToken, + + expectRefreshes: []testRefresh{ + { + expectRefreshToken: "rt1", + returnTokens: oauth2.TokenResponse{ + IDToken: goodToken2.Encode(), + }, + }, + }, + + expectRequests: []testRoundTrip{ + { + expectBearerToken: goodToken.Encode(), + returnHTTPStatus: http.StatusUnauthorized, + }, + { + expectBearerToken: goodToken2.Encode(), + returnHTTPStatus: http.StatusOK, + }, + }, + + expectPersists: []testPersist{ + { + cfg: map[string]string{ + cfgIDToken: goodToken2.Encode(), + cfgRefreshToken: "rt1", + }, + }, + }, + wantStatus: 200, + }, + } + + for i, tt := range tests { + client := &testOIDCClient{ + refreshes: tt.expectRefreshes, + } + + persister := &testPersister{ + tt.expectPersists, + } + + cfg := map[string]string{} + if tt.cfgIDToken != nil { + cfg[cfgIDToken] = tt.cfgIDToken.Encode() + } + + if tt.cfgRefreshToken != nil { + cfg[cfgRefreshToken] = *tt.cfgRefreshToken + } + + ap := &oidcAuthProvider{ + refresher: &idTokenRefresher{ + client: client, + cfg: cfg, + persister: persister, + }, + } + + if tt.cfgIDToken != nil { + ap.initialIDToken = *tt.cfgIDToken + } + + tstRT := &testRoundTripper{ + tt.expectRequests, + } + + rt := ap.WrapTransport(tstRT) + + req, err := http.NewRequest("GET", "http://cluster.example.com", nil) + if err != nil { + t.Errorf("case %d: unexpected error making request: %v", i, err) + } + + res, err := rt.RoundTrip(req) + if tt.wantErr { + if err == nil { + t.Errorf("case %d: Expected non-nil error", i) + } + } else if err != nil { + t.Errorf("case %d: unexpected error making round trip: %v", i, err) + + } else { + if res.StatusCode != tt.wantStatus { + t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode) + } + } + + if err = client.verify(); err != nil { + t.Errorf("case %d: %v", i, err) + } + + if err = persister.verify(); err != nil { + t.Errorf("case %d: %v", i, err) + } + + if err = tstRT.verify(); err != nil { + t.Errorf("case %d: %v", i, err) + continue + } + + } +} + +type testRoundTrip struct { + expectBearerToken string + returnHTTPStatus int +} + +type testRoundTripper struct { + trips []testRoundTrip +} + +func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if len(t.trips) == 0 { + return nil, errors.New("unexpected RoundTrip call") + } + + var trip testRoundTrip + trip, t.trips = t.trips[0], t.trips[1:] + + var bt string + var parts []string + auth := strings.TrimSpace(req.Header.Get("Authorization")) + if auth == "" { + goto Compare + } + + parts = strings.Split(auth, " ") + if len(parts) < 2 || strings.ToLower(parts[0]) != "bearer" { + goto Compare + } + + bt = parts[1] + +Compare: + if trip.expectBearerToken != bt { + return nil, fmt.Errorf("want bearerToken=%v, got=%v", trip.expectBearerToken, bt) + } + return &http.Response{ + StatusCode: trip.returnHTTPStatus, + }, nil +} + +func (t *testRoundTripper) verify() error { + if l := len(t.trips); l > 0 { + return fmt.Errorf("%d uncalled round trips", l) + } + return nil +} + +type testPersist struct { + cfg map[string]string + returnErr error +} + +type testPersister struct { + persists []testPersist +} + +func (t *testPersister) Persist(cfg map[string]string) error { + if len(t.persists) == 0 { + return errors.New("unexpected persist call") + } + + var persist testPersist + persist, t.persists = t.persists[0], t.persists[1:] + + if !reflect.DeepEqual(persist.cfg, cfg) { + return fmt.Errorf("Unexpected cfg: %v", diff.ObjectDiff(persist.cfg, cfg)) + } + + return persist.returnErr +} + +func (t *testPersister) verify() error { + if l := len(t.persists); l > 0 { + return fmt.Errorf("%d uncalled persists", l) + } + return nil +} + +type testRefresh struct { + expectRefreshToken string + + returnErr error + returnTokens oauth2.TokenResponse +} + +type testOIDCClient struct { + refreshes []testRefresh +} + +func (o *testOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) { + if len(o.refreshes) == 0 { + return oauth2.TokenResponse{}, errors.New("unexpected refresh request") + } + + var refresh testRefresh + refresh, o.refreshes = o.refreshes[0], o.refreshes[1:] + + if rt != refresh.expectRefreshToken { + return oauth2.TokenResponse{}, fmt.Errorf("want rt=%v, got=%v", + refresh.expectRefreshToken, + rt) + } + + if refresh.returnErr != nil { + return oauth2.TokenResponse{}, refresh.returnErr + } + + return refresh.returnTokens, nil +} + +func (o *testOIDCClient) verifyJWT(jwt jose.JWT) error { + claims, err := jwt.Claims() + if err != nil { + return err + } + claim, _, _ := claims.StringClaim("test") + if claim != "good" { + return errors.New("bad token") + } + return nil +} + +func (t *testOIDCClient) verify() error { + if l := len(t.refreshes); l > 0 { + return fmt.Errorf("%d uncalled refreshes", l) + } + return nil +} + func compareJWTs(a, b jose.JWT) string { if a.Encode() == b.Encode() { return "" @@ -179,5 +638,5 @@ func compareJWTs(a, b jose.JWT) string { } } - return diff.ObjectDiff(a, b) + return diff.ObjectDiff(aClaims, bClaims) }