mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-06 02:34:03 +00:00
add tests for newOIDCAuthProvider
This commit is contained in:
parent
c990462d0f
commit
e85940ed17
@ -19,7 +19,6 @@ package oidc
|
||||
import (
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
@ -34,14 +33,6 @@ import (
|
||||
oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing"
|
||||
)
|
||||
|
||||
func mustParseURL(t *testing.T, s string) *url.URL {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse url: %v", err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func generateToken(t *testing.T, op *oidctesting.OIDCProvider, iss, sub, aud string, usernameClaim, value, groupsClaim string, groups []string, iat, exp time.Time) string {
|
||||
signer := op.PrivKey.Signer()
|
||||
claims := oidc.NewClaims(iss, sub, aud, iat, exp)
|
||||
@ -98,7 +89,7 @@ func TestOIDCDiscoveryNoKeyEndpoint(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
op.PCFG = oidc.ProviderConfig{
|
||||
Issuer: mustParseURL(t, srv.URL), // An invalid ProviderConfig. Keys endpoint is required.
|
||||
Issuer: oidctesting.MustParseURL(srv.URL), // An invalid ProviderConfig. Keys endpoint is required.
|
||||
}
|
||||
|
||||
_, err = New(OIDCOptions{srv.URL, "client-foo", cert, "sub", "", 0, 0})
|
||||
@ -114,8 +105,8 @@ func TestOIDCDiscoverySecureConnection(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
op.PCFG = oidc.ProviderConfig{
|
||||
Issuer: mustParseURL(t, srv.URL),
|
||||
KeysEndpoint: mustParseURL(t, srv.URL+"/keys"),
|
||||
Issuer: oidctesting.MustParseURL(srv.URL),
|
||||
KeysEndpoint: oidctesting.MustParseURL(srv.URL + "/keys"),
|
||||
}
|
||||
|
||||
expectErr := fmt.Errorf("'oidc-issuer-url' (%q) has invalid scheme (%q), require 'https'", srv.URL, "http")
|
||||
@ -147,8 +138,8 @@ func TestOIDCDiscoverySecureConnection(t *testing.T) {
|
||||
defer tlsSrv.Close()
|
||||
|
||||
op.PCFG = oidc.ProviderConfig{
|
||||
Issuer: mustParseURL(t, tlsSrv.URL),
|
||||
KeysEndpoint: mustParseURL(t, tlsSrv.URL+"/keys"),
|
||||
Issuer: oidctesting.MustParseURL(tlsSrv.URL),
|
||||
KeysEndpoint: oidctesting.MustParseURL(tlsSrv.URL + "/keys"),
|
||||
}
|
||||
|
||||
// Create a client using cert2, should fail.
|
||||
@ -179,15 +170,7 @@ func TestOIDCAuthentication(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
// A provider config with all required fields.
|
||||
op.PCFG = oidc.ProviderConfig{
|
||||
Issuer: mustParseURL(t, srv.URL),
|
||||
AuthEndpoint: mustParseURL(t, srv.URL+"/auth"),
|
||||
TokenEndpoint: mustParseURL(t, srv.URL+"/token"),
|
||||
KeysEndpoint: mustParseURL(t, srv.URL+"/keys"),
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
SubjectTypesSupported: []string{"public"},
|
||||
IDTokenSigningAlgValues: []string{"RS256"},
|
||||
}
|
||||
op.AddMinimalProviderConfig(srv)
|
||||
|
||||
tests := []struct {
|
||||
userClaim string
|
||||
|
@ -31,6 +31,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@ -79,6 +80,19 @@ func (op *OIDCProvider) ServeTLSWithKeyPair(cert, key string) (*httptest.Server,
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
func (op *OIDCProvider) AddMinimalProviderConfig(srv *httptest.Server) {
|
||||
op.PCFG = oidc.ProviderConfig{
|
||||
Issuer: MustParseURL(srv.URL),
|
||||
AuthEndpoint: MustParseURL(srv.URL + "/auth"),
|
||||
TokenEndpoint: MustParseURL(srv.URL + "/token"),
|
||||
KeysEndpoint: MustParseURL(srv.URL + "/keys"),
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
SubjectTypesSupported: []string{"public"},
|
||||
IDTokenSigningAlgValues: []string{"RS256"},
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (op *OIDCProvider) handleConfig(w http.ResponseWriter, req *http.Request) {
|
||||
b, err := json.Marshal(&op.PCFG)
|
||||
if err != nil {
|
||||
@ -108,6 +122,14 @@ func (op *OIDCProvider) handleKeys(w http.ResponseWriter, req *http.Request) {
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
func MustParseURL(s string) *url.URL {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("Failed to parse url: %v", err))
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
// generateSelfSignedCert generates a self-signed cert/key pairs and writes to the certPath/keyPath.
|
||||
// This method is mostly identical to crypto.GenerateSelfSignedCert except for the 'IsCA' and 'KeyUsage'
|
||||
// in the certificate template. (Maybe we can merge these two methods).
|
||||
|
@ -34,12 +34,12 @@ import (
|
||||
|
||||
const (
|
||||
cfgIssuerUrl = "idp-issuer-url"
|
||||
cfgClientId = "client-id"
|
||||
cfgClientID = "client-id"
|
||||
cfgClientSecret = "client-secret"
|
||||
cfgCertificateAuthority = "idp-certificate-authority"
|
||||
cfgCertificateAuthorityData = "idp-certificate-authority-data"
|
||||
cfgExtraScopes = "extra-scopes"
|
||||
cfgIdToken = "id-token"
|
||||
cfgIDToken = "id-token"
|
||||
cfgRefreshToken = "refresh-token"
|
||||
)
|
||||
|
||||
@ -50,6 +50,21 @@ func init() {
|
||||
}
|
||||
|
||||
func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
|
||||
issuer := cfg[cfgIssuerUrl]
|
||||
if issuer == "" {
|
||||
return nil, fmt.Errorf("Must provide %s", cfgIssuerUrl)
|
||||
}
|
||||
|
||||
clientID := cfg[cfgClientID]
|
||||
if clientID == "" {
|
||||
return nil, fmt.Errorf("Must provide %s", cfgClientID)
|
||||
}
|
||||
|
||||
clientSecret := cfg[cfgClientSecret]
|
||||
if clientSecret == "" {
|
||||
return nil, fmt.Errorf("Must provide %s", cfgClientSecret)
|
||||
}
|
||||
|
||||
var certAuthData []byte
|
||||
var err error
|
||||
if cfg[cfgCertificateAuthorityData] != "" {
|
||||
@ -72,11 +87,6 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
|
||||
}
|
||||
hc := &http.Client{Transport: trans}
|
||||
|
||||
issuer, ok := cfg[cfgIssuerUrl]
|
||||
if !ok || issuer == "" {
|
||||
return nil, errors.New("Must provide idp-issuer-url")
|
||||
}
|
||||
|
||||
providerCfg, err := oidc.FetchProviderConfig(hc, strings.TrimSuffix(issuer, "/"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching provider config: %v", err)
|
||||
@ -86,8 +96,8 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
|
||||
oidcCfg := oidc.ClientConfig{
|
||||
HTTPClient: hc,
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: cfg[cfgClientId],
|
||||
Secret: cfg[cfgClientSecret],
|
||||
ID: clientID,
|
||||
Secret: clientSecret,
|
||||
},
|
||||
ProviderConfig: providerCfg,
|
||||
Scope: scopes,
|
||||
@ -99,15 +109,15 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
|
||||
}
|
||||
|
||||
var initialIDToken jose.JWT
|
||||
if cfg[cfgIdToken] != "" {
|
||||
initialIDToken, err = jose.ParseJWT(cfg[cfgIdToken])
|
||||
if cfg[cfgIDToken] != "" {
|
||||
initialIDToken, err = jose.ParseJWT(cfg[cfgIDToken])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &oidcAuthProvider{
|
||||
intialIDToken: initialIDToken,
|
||||
initialIDToken: initialIDToken,
|
||||
refresher: &idTokenRefresher{
|
||||
client: client,
|
||||
cfg: cfg,
|
||||
@ -117,8 +127,8 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
|
||||
}
|
||||
|
||||
type oidcAuthProvider struct {
|
||||
refresher *idTokenRefresher
|
||||
intialIDToken jose.JWT
|
||||
refresher *idTokenRefresher
|
||||
initialIDToken jose.JWT
|
||||
}
|
||||
|
||||
func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
|
||||
@ -126,7 +136,7 @@ func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper
|
||||
TokenRefresher: g.refresher,
|
||||
RoundTripper: rt,
|
||||
}
|
||||
at.SetJWT(g.intialIDToken)
|
||||
at.SetJWT(g.initialIDToken)
|
||||
return at
|
||||
}
|
||||
|
||||
@ -185,7 +195,7 @@ func (r *idTokenRefresher) Refresh() (jose.JWT, error) {
|
||||
if tokens.RefreshToken != "" && tokens.RefreshToken != rt {
|
||||
r.cfg[cfgRefreshToken] = tokens.RefreshToken
|
||||
}
|
||||
r.cfg[cfgIdToken] = jwt.Encode()
|
||||
r.cfg[cfgIDToken] = jwt.Encode()
|
||||
|
||||
err = r.persister.Persist(r.cfg)
|
||||
if err != nil {
|
||||
|
@ -1,14 +1,62 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"k8s.io/kubernetes/pkg/util/diff"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
|
||||
"k8s.io/kubernetes/pkg/util/diff"
|
||||
oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing"
|
||||
)
|
||||
|
||||
func TestNewOIDCAuthProvider(t *testing.T) {
|
||||
cert := path.Join(os.TempDir(), "oidc-cert")
|
||||
key := path.Join(os.TempDir(), "oidc-key")
|
||||
|
||||
defer os.Remove(cert)
|
||||
defer os.Remove(key)
|
||||
|
||||
oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert, key)
|
||||
op := oidctesting.NewOIDCProvider(t)
|
||||
srv, err := op.ServeTLSWithKeyPair(cert, key)
|
||||
op.AddMinimalProviderConfig(srv)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot start server %v", err)
|
||||
}
|
||||
defer srv.Close()
|
||||
|
||||
certData, err := ioutil.ReadFile(cert)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not read cert bytes %v", err)
|
||||
}
|
||||
|
||||
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
||||
"test": "jwt",
|
||||
}), op.PrivKey.Signer())
|
||||
if err != nil {
|
||||
t.Fatalf("Could not create signed JWT %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
cfg map[string]string
|
||||
|
||||
@ -16,10 +64,70 @@ func TestNewOIDCAuthProvider(t *testing.T) {
|
||||
wantInitialIDToken jose.JWT
|
||||
}{
|
||||
{
|
||||
// A Valid configuration
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: "auth.example.com",
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
},
|
||||
},
|
||||
{
|
||||
// A Valid configuration with an Initial JWT
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgIDToken: jwt.Encode(),
|
||||
},
|
||||
wantInitialIDToken: *jwt,
|
||||
},
|
||||
{
|
||||
// Valid config, but using cfgCertificateAuthorityData
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData),
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Missing client id
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientSecret: "client-secret",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
// Missing client secret
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
// Missing issuer url.
|
||||
cfg: map[string]string{
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "secret",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
// No TLS config
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "secret",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
@ -27,8 +135,8 @@ func TestNewOIDCAuthProvider(t *testing.T) {
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user