diff --git a/plugin/pkg/client/auth/oidc/oidc.go b/plugin/pkg/client/auth/oidc/oidc.go index bf37d569b8e..3ad279c106e 100644 --- a/plugin/pkg/client/auth/oidc/oidc.go +++ b/plugin/pkg/client/auth/oidc/oidc.go @@ -30,6 +30,7 @@ import ( "github.com/golang/glog" "k8s.io/kubernetes/pkg/client/restclient" + "k8s.io/kubernetes/pkg/util/wait" ) const ( @@ -43,6 +44,15 @@ const ( cfgRefreshToken = "refresh-token" ) +var ( + backoff = wait.Backoff{ + Duration: 1 * time.Second, + Factor: 2, + Jitter: .1, + Steps: 5, + } +) + func init() { if err := restclient.RegisterAuthProviderPlugin("oidc", newOIDCAuthProvider); err != nil { glog.Fatalf("Failed to register oidc auth plugin: %v", err) @@ -100,7 +110,7 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A Secret: clientSecret, }, ProviderConfig: providerCfg, - Scope: scopes, + Scope: append(scopes, oidc.DefaultScope...), } client, err := oidc.NewClient(oidcCfg) @@ -108,6 +118,8 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A return nil, fmt.Errorf("error creating OIDC Client: %v", err) } + oClient := &oidcClient{client} + var initialIDToken jose.JWT if cfg[cfgIDToken] != "" { initialIDToken, err = jose.ParseJWT(cfg[cfgIDToken]) @@ -119,7 +131,7 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A return &oidcAuthProvider{ initialIDToken: initialIDToken, refresher: &idTokenRefresher{ - client: client, + client: oClient, cfg: cfg, persister: persister, }, @@ -137,16 +149,57 @@ func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper RoundTripper: rt, } at.SetJWT(g.initialIDToken) - return at + return &roundTripper{ + wrapped: at, + refresher: g.refresher, + } } func (g *oidcAuthProvider) Login() error { return errors.New("not yet implemented") } +type OIDCClient interface { + refreshToken(rt string) (oauth2.TokenResponse, error) + verifyJWT(jwt jose.JWT) error +} + +type roundTripper struct { + refresher *idTokenRefresher + wrapped *oidc.AuthenticatedTransport +} + +func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + var res *http.Response + var err error + firstTime := true + wait.ExponentialBackoff(backoff, func() (bool, error) { + if !firstTime { + var jwt jose.JWT + jwt, err = r.refresher.Refresh() + if err != nil { + return true, nil + } + r.wrapped.SetJWT(jwt) + } else { + firstTime = false + } + + res, err = r.wrapped.RoundTrip(req) + if err != nil { + return true, nil + } + if res.StatusCode == http.StatusUnauthorized { + return false, nil + } + return true, nil + }) + return res, err +} + type idTokenRefresher struct { cfg map[string]string - client *oidc.Client + client OIDCClient persister restclient.AuthProviderConfigPersister intialIDToken jose.JWT } @@ -177,16 +230,10 @@ func (r *idTokenRefresher) Refresh() (jose.JWT, error) { return jose.JWT{}, errors.New("No valid id-token, and cannot refresh without refresh-token") } - oac, err := r.client.OAuthClient() + tokens, err := r.client.refreshToken(rt) if err != nil { - return jose.JWT{}, err + return jose.JWT{}, fmt.Errorf("could not refresh token: %v", err) } - - tokens, err := oac.RequestToken(oauth2.GrantTypeRefreshToken, rt) - if err != nil { - return jose.JWT{}, err - } - jwt, err := jose.ParseJWT(tokens.IDToken) if err != nil { return jose.JWT{}, err @@ -202,5 +249,22 @@ func (r *idTokenRefresher) Refresh() (jose.JWT, error) { return jose.JWT{}, fmt.Errorf("could not perist new tokens: %v", err) } - return jwt, r.client.VerifyJWT(jwt) + return jwt, r.client.verifyJWT(jwt) +} + +type oidcClient struct { + client *oidc.Client +} + +func (o *oidcClient) refreshToken(rt string) (oauth2.TokenResponse, error) { + oac, err := o.client.OAuthClient() + if err != nil { + return oauth2.TokenResponse{}, err + } + + return oac.RequestToken(oauth2.GrantTypeRefreshToken, rt) +} + +func (o *oidcClient) verifyJWT(jwt jose.JWT) error { + return o.client.VerifyJWT(jwt) }