mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-02 16:29:21 +00:00
add tests for the OIDC WrapTransport
tests that tokens gets refreshed, passed along as bearers, etc.
This commit is contained in:
parent
94ffa344a8
commit
f575f89cd7
@ -18,14 +18,23 @@ package oidc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/jose"
|
"github.com/coreos/go-oidc/jose"
|
||||||
|
"github.com/coreos/go-oidc/key"
|
||||||
|
"github.com/coreos/go-oidc/oauth2"
|
||||||
|
|
||||||
"k8s.io/kubernetes/pkg/util/diff"
|
"k8s.io/kubernetes/pkg/util/diff"
|
||||||
|
"k8s.io/kubernetes/pkg/util/wait"
|
||||||
oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing"
|
oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -156,6 +165,456 @@ func TestNewOIDCAuthProvider(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWrapTranport(t *testing.T) {
|
||||||
|
oldBackoff := backoff
|
||||||
|
defer func() {
|
||||||
|
backoff = oldBackoff
|
||||||
|
}()
|
||||||
|
backoff = wait.Backoff{
|
||||||
|
Duration: 1 * time.Nanosecond,
|
||||||
|
Steps: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
privKey, err := key.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't generate private key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
makeToken := func(s string, exp time.Time, count int) *jose.JWT {
|
||||||
|
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
||||||
|
"test": s,
|
||||||
|
"exp": exp.UTC().Unix(),
|
||||||
|
"count": count,
|
||||||
|
}), privKey.Signer())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not create signed JWT %v", err)
|
||||||
|
}
|
||||||
|
return jwt
|
||||||
|
}
|
||||||
|
|
||||||
|
goodToken := makeToken("good", time.Now().Add(time.Hour), 0)
|
||||||
|
goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1)
|
||||||
|
expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0)
|
||||||
|
|
||||||
|
str := func(s string) *string {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
cfgIDToken *jose.JWT
|
||||||
|
cfgRefreshToken *string
|
||||||
|
|
||||||
|
expectRequests []testRoundTrip
|
||||||
|
|
||||||
|
expectRefreshes []testRefresh
|
||||||
|
|
||||||
|
expectPersists []testPersist
|
||||||
|
|
||||||
|
wantStatus int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// Initial JWT is set, it is good, it is set as bearer.
|
||||||
|
cfgIDToken: goodToken,
|
||||||
|
|
||||||
|
expectRequests: []testRoundTrip{
|
||||||
|
{
|
||||||
|
expectBearerToken: goodToken.Encode(),
|
||||||
|
returnHTTPStatus: 200,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
wantStatus: 200,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Initial JWT is set, but it's expired, so it gets refreshed.
|
||||||
|
cfgIDToken: expiredToken,
|
||||||
|
cfgRefreshToken: str("rt1"),
|
||||||
|
|
||||||
|
expectRefreshes: []testRefresh{
|
||||||
|
{
|
||||||
|
expectRefreshToken: "rt1",
|
||||||
|
returnTokens: oauth2.TokenResponse{
|
||||||
|
IDToken: goodToken.Encode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
expectRequests: []testRoundTrip{
|
||||||
|
{
|
||||||
|
expectBearerToken: goodToken.Encode(),
|
||||||
|
returnHTTPStatus: 200,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
expectPersists: []testPersist{
|
||||||
|
{
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgIDToken: goodToken.Encode(),
|
||||||
|
cfgRefreshToken: "rt1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
wantStatus: 200,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Initial JWT is set, but it's expired, so it gets refreshed - this
|
||||||
|
// time the refresh token itself is also refreshed
|
||||||
|
cfgIDToken: expiredToken,
|
||||||
|
cfgRefreshToken: str("rt1"),
|
||||||
|
|
||||||
|
expectRefreshes: []testRefresh{
|
||||||
|
{
|
||||||
|
expectRefreshToken: "rt1",
|
||||||
|
returnTokens: oauth2.TokenResponse{
|
||||||
|
IDToken: goodToken.Encode(),
|
||||||
|
RefreshToken: "rt2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
expectRequests: []testRoundTrip{
|
||||||
|
{
|
||||||
|
expectBearerToken: goodToken.Encode(),
|
||||||
|
returnHTTPStatus: 200,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
expectPersists: []testPersist{
|
||||||
|
{
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgIDToken: goodToken.Encode(),
|
||||||
|
cfgRefreshToken: "rt2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
wantStatus: 200,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Initial JWT is not set, so it gets refreshed.
|
||||||
|
cfgRefreshToken: str("rt1"),
|
||||||
|
|
||||||
|
expectRefreshes: []testRefresh{
|
||||||
|
{
|
||||||
|
expectRefreshToken: "rt1",
|
||||||
|
returnTokens: oauth2.TokenResponse{
|
||||||
|
IDToken: goodToken.Encode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
expectRequests: []testRoundTrip{
|
||||||
|
{
|
||||||
|
expectBearerToken: goodToken.Encode(),
|
||||||
|
returnHTTPStatus: 200,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
expectPersists: []testPersist{
|
||||||
|
{
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgIDToken: goodToken.Encode(),
|
||||||
|
cfgRefreshToken: "rt1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
wantStatus: 200,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Expired token, but no refresh token.
|
||||||
|
cfgIDToken: expiredToken,
|
||||||
|
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Initial JWT is not set, so it gets refreshed, but the server
|
||||||
|
// rejects it when it is used, so it refreshes again, which
|
||||||
|
// succeeds.
|
||||||
|
cfgRefreshToken: str("rt1"),
|
||||||
|
|
||||||
|
expectRefreshes: []testRefresh{
|
||||||
|
{
|
||||||
|
expectRefreshToken: "rt1",
|
||||||
|
returnTokens: oauth2.TokenResponse{
|
||||||
|
IDToken: goodToken.Encode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
expectRefreshToken: "rt1",
|
||||||
|
returnTokens: oauth2.TokenResponse{
|
||||||
|
IDToken: goodToken2.Encode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
expectRequests: []testRoundTrip{
|
||||||
|
{
|
||||||
|
expectBearerToken: goodToken.Encode(),
|
||||||
|
returnHTTPStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
expectBearerToken: goodToken2.Encode(),
|
||||||
|
returnHTTPStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
expectPersists: []testPersist{
|
||||||
|
{
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgIDToken: goodToken.Encode(),
|
||||||
|
cfgRefreshToken: "rt1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgIDToken: goodToken2.Encode(),
|
||||||
|
cfgRefreshToken: "rt1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
wantStatus: 200,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Initial JWT is but the server rejects it when it is used, so it
|
||||||
|
// refreshes again, which succeeds.
|
||||||
|
cfgRefreshToken: str("rt1"),
|
||||||
|
cfgIDToken: goodToken,
|
||||||
|
|
||||||
|
expectRefreshes: []testRefresh{
|
||||||
|
{
|
||||||
|
expectRefreshToken: "rt1",
|
||||||
|
returnTokens: oauth2.TokenResponse{
|
||||||
|
IDToken: goodToken2.Encode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
expectRequests: []testRoundTrip{
|
||||||
|
{
|
||||||
|
expectBearerToken: goodToken.Encode(),
|
||||||
|
returnHTTPStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
expectBearerToken: goodToken2.Encode(),
|
||||||
|
returnHTTPStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
expectPersists: []testPersist{
|
||||||
|
{
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgIDToken: goodToken2.Encode(),
|
||||||
|
cfgRefreshToken: "rt1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStatus: 200,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
client := &testOIDCClient{
|
||||||
|
refreshes: tt.expectRefreshes,
|
||||||
|
}
|
||||||
|
|
||||||
|
persister := &testPersister{
|
||||||
|
tt.expectPersists,
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := map[string]string{}
|
||||||
|
if tt.cfgIDToken != nil {
|
||||||
|
cfg[cfgIDToken] = tt.cfgIDToken.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.cfgRefreshToken != nil {
|
||||||
|
cfg[cfgRefreshToken] = *tt.cfgRefreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
ap := &oidcAuthProvider{
|
||||||
|
refresher: &idTokenRefresher{
|
||||||
|
client: client,
|
||||||
|
cfg: cfg,
|
||||||
|
persister: persister,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.cfgIDToken != nil {
|
||||||
|
ap.initialIDToken = *tt.cfgIDToken
|
||||||
|
}
|
||||||
|
|
||||||
|
tstRT := &testRoundTripper{
|
||||||
|
tt.expectRequests,
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := ap.WrapTransport(tstRT)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", "http://cluster.example.com", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("case %d: unexpected error making request: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := rt.RoundTrip(req)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("case %d: Expected non-nil error", i)
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
t.Errorf("case %d: unexpected error making round trip: %v", i, err)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
if res.StatusCode != tt.wantStatus {
|
||||||
|
t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = client.verify(); err != nil {
|
||||||
|
t.Errorf("case %d: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = persister.verify(); err != nil {
|
||||||
|
t.Errorf("case %d: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tstRT.verify(); err != nil {
|
||||||
|
t.Errorf("case %d: %v", i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testRoundTrip struct {
|
||||||
|
expectBearerToken string
|
||||||
|
returnHTTPStatus int
|
||||||
|
}
|
||||||
|
|
||||||
|
type testRoundTripper struct {
|
||||||
|
trips []testRoundTrip
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
if len(t.trips) == 0 {
|
||||||
|
return nil, errors.New("unexpected RoundTrip call")
|
||||||
|
}
|
||||||
|
|
||||||
|
var trip testRoundTrip
|
||||||
|
trip, t.trips = t.trips[0], t.trips[1:]
|
||||||
|
|
||||||
|
var bt string
|
||||||
|
var parts []string
|
||||||
|
auth := strings.TrimSpace(req.Header.Get("Authorization"))
|
||||||
|
if auth == "" {
|
||||||
|
goto Compare
|
||||||
|
}
|
||||||
|
|
||||||
|
parts = strings.Split(auth, " ")
|
||||||
|
if len(parts) < 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||||
|
goto Compare
|
||||||
|
}
|
||||||
|
|
||||||
|
bt = parts[1]
|
||||||
|
|
||||||
|
Compare:
|
||||||
|
if trip.expectBearerToken != bt {
|
||||||
|
return nil, fmt.Errorf("want bearerToken=%v, got=%v", trip.expectBearerToken, bt)
|
||||||
|
}
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: trip.returnHTTPStatus,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoundTripper) verify() error {
|
||||||
|
if l := len(t.trips); l > 0 {
|
||||||
|
return fmt.Errorf("%d uncalled round trips", l)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type testPersist struct {
|
||||||
|
cfg map[string]string
|
||||||
|
returnErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
type testPersister struct {
|
||||||
|
persists []testPersist
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testPersister) Persist(cfg map[string]string) error {
|
||||||
|
if len(t.persists) == 0 {
|
||||||
|
return errors.New("unexpected persist call")
|
||||||
|
}
|
||||||
|
|
||||||
|
var persist testPersist
|
||||||
|
persist, t.persists = t.persists[0], t.persists[1:]
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(persist.cfg, cfg) {
|
||||||
|
return fmt.Errorf("Unexpected cfg: %v", diff.ObjectDiff(persist.cfg, cfg))
|
||||||
|
}
|
||||||
|
|
||||||
|
return persist.returnErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testPersister) verify() error {
|
||||||
|
if l := len(t.persists); l > 0 {
|
||||||
|
return fmt.Errorf("%d uncalled persists", l)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type testRefresh struct {
|
||||||
|
expectRefreshToken string
|
||||||
|
|
||||||
|
returnErr error
|
||||||
|
returnTokens oauth2.TokenResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
type testOIDCClient struct {
|
||||||
|
refreshes []testRefresh
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *testOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
||||||
|
if len(o.refreshes) == 0 {
|
||||||
|
return oauth2.TokenResponse{}, errors.New("unexpected refresh request")
|
||||||
|
}
|
||||||
|
|
||||||
|
var refresh testRefresh
|
||||||
|
refresh, o.refreshes = o.refreshes[0], o.refreshes[1:]
|
||||||
|
|
||||||
|
if rt != refresh.expectRefreshToken {
|
||||||
|
return oauth2.TokenResponse{}, fmt.Errorf("want rt=%v, got=%v",
|
||||||
|
refresh.expectRefreshToken,
|
||||||
|
rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if refresh.returnErr != nil {
|
||||||
|
return oauth2.TokenResponse{}, refresh.returnErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return refresh.returnTokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *testOIDCClient) verifyJWT(jwt jose.JWT) error {
|
||||||
|
claims, err := jwt.Claims()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
claim, _, _ := claims.StringClaim("test")
|
||||||
|
if claim != "good" {
|
||||||
|
return errors.New("bad token")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testOIDCClient) verify() error {
|
||||||
|
if l := len(t.refreshes); l > 0 {
|
||||||
|
return fmt.Errorf("%d uncalled refreshes", l)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func compareJWTs(a, b jose.JWT) string {
|
func compareJWTs(a, b jose.JWT) string {
|
||||||
if a.Encode() == b.Encode() {
|
if a.Encode() == b.Encode() {
|
||||||
return ""
|
return ""
|
||||||
@ -179,5 +638,5 @@ func compareJWTs(a, b jose.JWT) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return diff.ObjectDiff(a, b)
|
return diff.ObjectDiff(aClaims, bClaims)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user