exec auth: support TLS config caching

This change updates the transport.Config .Dial and .TLS.GetCert fields
to use a struct wrapper.  This indirection via a pointer allows the
functions to be compared and thus makes them valid to use as map keys.
This change is then leveraged by the existing global exec auth and TLS
config caches to return the same authenticator and TLS config even when
distinct but identical rest configs were used to create distinct
clientsets.

Signed-off-by: Monis Khan <mok@microsoft.com>

Kubernetes-commit: 831d95b6a021c2767effe85e461309cb6a0fdcec
This commit is contained in:
Monis Khan 2022-08-24 16:04:19 +00:00 committed by Kubernetes Publisher
parent 2698e8276e
commit 6a008ec216
7 changed files with 354 additions and 15 deletions

View File

@ -199,14 +199,18 @@ func newAuthenticator(c *cache, isTerminalFunc func(int) bool, config *api.ExecC
now: time.Now, now: time.Now,
environ: os.Environ, environ: os.Environ,
defaultDialer: defaultDialer, connTracker: connTracker,
connTracker: connTracker,
} }
for _, env := range config.Env { for _, env := range config.Env {
a.env = append(a.env, env.Name+"="+env.Value) a.env = append(a.env, env.Name+"="+env.Value)
} }
// these functions are made comparable and stored in the cache so that repeated clientset
// construction with the same rest.Config results in a single TLS cache and Authenticator
a.getCert = &transport.GetCertHolder{GetCert: a.cert}
a.dial = &transport.DialHolder{Dial: defaultDialer.DialContext}
return c.put(key, a), nil return c.put(key, a), nil
} }
@ -261,8 +265,6 @@ type Authenticator struct {
now func() time.Time now func() time.Time
environ func() []string environ func() []string
// defaultDialer is used for clients which don't specify a custom dialer
defaultDialer *connrotation.Dialer
// connTracker tracks all connections opened that we need to close when rotating a client certificate // connTracker tracks all connections opened that we need to close when rotating a client certificate
connTracker *connrotation.ConnectionTracker connTracker *connrotation.ConnectionTracker
@ -273,6 +275,12 @@ type Authenticator struct {
mu sync.Mutex mu sync.Mutex
cachedCreds *credentials cachedCreds *credentials
exp time.Time exp time.Time
// getCert makes Authenticator.cert comparable to support TLS config caching
getCert *transport.GetCertHolder
// dial is used for clients which do not specify a custom dialer
// it is comparable to support TLS config caching
dial *transport.DialHolder
} }
type credentials struct { type credentials struct {
@ -300,18 +308,20 @@ func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error {
if c.HasCertCallback() { if c.HasCertCallback() {
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set") return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")
} }
c.TLS.GetCert = a.cert c.TLS.GetCert = a.getCert.GetCert
c.TLS.GetCertHolder = a.getCert // comparable for TLS config caching
var d *connrotation.Dialer
if c.Dial != nil { if c.Dial != nil {
// if c has a custom dialer, we have to wrap it // if c has a custom dialer, we have to wrap it
d = connrotation.NewDialerWithTracker(c.Dial, a.connTracker) // TLS config caching is not supported for this config
d := connrotation.NewDialerWithTracker(c.Dial, a.connTracker)
c.Dial = d.DialContext
c.DialHolder = nil
} else { } else {
d = a.defaultDialer c.Dial = a.dial.Dial
c.DialHolder = a.dial // comparable for TLS config caching
} }
c.Dial = d.DialContext
return nil return nil
} }

View File

@ -0,0 +1,106 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package exec_test // separate package to prevent circular import
import (
"context"
"testing"
"time"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
utilnet "k8s.io/apimachinery/pkg/util/net"
clientset "k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
)
// TestExecTLSCache asserts the semantics of the TLS cache when exec auth is used.
//
// In particular, when:
// - multiple identical rest configs exist as distinct objects, and
// - these rest configs use exec auth, and
// - these rest configs are used to create distinct clientsets, then
//
// the underlying TLS config is shared between those clientsets.
func TestExecTLSCache(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
t.Cleanup(cancel)
config1 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client1 := clientset.NewForConfigOrDie(config1)
config2 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client2 := clientset.NewForConfigOrDie(config2)
config3 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
Args: []string{"make this exec auth different"},
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client3 := clientset.NewForConfigOrDie(config3)
_, _ = client1.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
_, _ = client2.CoreV1().Namespaces().List(ctx, metav1.ListOptions{})
_, _ = client3.CoreV1().PersistentVolumes().List(ctx, metav1.ListOptions{})
rt1 := client1.RESTClient().(*rest.RESTClient).Client.Transport
rt2 := client2.RESTClient().(*rest.RESTClient).Client.Transport
rt3 := client3.RESTClient().(*rest.RESTClient).Client.Transport
tlsConfig1, err := utilnet.TLSClientConfig(rt1)
if err != nil {
t.Fatal(err)
}
tlsConfig2, err := utilnet.TLSClientConfig(rt2)
if err != nil {
t.Fatal(err)
}
tlsConfig3, err := utilnet.TLSClientConfig(rt3)
if err != nil {
t.Fatal(err)
}
if tlsConfig1 == nil || tlsConfig2 == nil || tlsConfig3 == nil {
t.Fatal("expected non-nil TLS configs")
}
if tlsConfig1 != tlsConfig2 {
t.Fatal("expected the same TLS config for matching exec config via rest config")
}
if tlsConfig1 == tlsConfig3 {
t.Fatal("expected different TLS config for non-matching exec config via rest config")
}
}

View File

@ -17,6 +17,7 @@ limitations under the License.
package transport package transport
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -55,6 +56,9 @@ type tlsCacheKey struct {
serverName string serverName string
nextProtos string nextProtos string
disableCompression bool disableCompression bool
// these functions are wrapped to allow them to be used as map keys
getCert *GetCertHolder
dial *DialHolder
} }
func (t tlsCacheKey) String() string { func (t tlsCacheKey) String() string {
@ -62,7 +66,8 @@ func (t tlsCacheKey) String() string {
if len(t.keyData) > 0 { if len(t.keyData) > 0 {
keyText = "<redacted>" keyText = "<redacted>"
} }
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression) return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p",
t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial)
} }
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
@ -92,8 +97,10 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
return http.DefaultTransport, nil return http.DefaultTransport, nil
} }
dial := config.Dial var dial func(ctx context.Context, network, address string) (net.Conn, error)
if dial == nil { if config.Dial != nil {
dial = config.Dial
} else {
dial = (&net.Dialer{ dial = (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
@ -138,10 +145,18 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
return tlsCacheKey{}, false, err return tlsCacheKey{}, false, err
} }
if c.TLS.GetCert != nil || c.Dial != nil || c.Proxy != nil { if c.Proxy != nil {
// cannot determine equality for functions // cannot determine equality for functions
return tlsCacheKey{}, false, nil return tlsCacheKey{}, false, nil
} }
if c.Dial != nil && c.DialHolder == nil {
// cannot determine equality for dial function that doesn't have non-nil DialHolder set as well
return tlsCacheKey{}, false, nil
}
if c.TLS.GetCert != nil && c.TLS.GetCertHolder == nil {
// cannot determine equality for getCert function that doesn't have non-nil GetCertHolder set as well
return tlsCacheKey{}, false, nil
}
k := tlsCacheKey{ k := tlsCacheKey{
insecure: c.TLS.Insecure, insecure: c.TLS.Insecure,
@ -149,6 +164,8 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
serverName: c.TLS.ServerName, serverName: c.TLS.ServerName,
nextProtos: strings.Join(c.TLS.NextProtos, ","), nextProtos: strings.Join(c.TLS.NextProtos, ","),
disableCompression: c.DisableCompression, disableCompression: c.DisableCompression,
getCert: c.TLS.GetCertHolder,
dial: c.DialHolder,
} }
if c.TLS.ReloadTLSFiles { if c.TLS.ReloadTLSFiles {

View File

@ -21,6 +21,7 @@ import (
"crypto/tls" "crypto/tls"
"net" "net"
"net/http" "net/http"
"net/url"
"testing" "testing"
) )
@ -58,16 +59,24 @@ func TestTLSConfigKey(t *testing.T) {
t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue continue
} }
if keyA != (tlsCacheKey{}) {
t.Errorf("Expected empty cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue
}
} }
} }
// Make sure config fields that affect the tls config affect the cache key // Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{} dialer := net.Dialer{}
getCert := func() (*tls.Certificate, error) { return nil, nil } getCert := func() (*tls.Certificate, error) { return nil, nil }
getCertHolder := &GetCertHolder{GetCert: getCert}
uniqueConfigurations := map[string]*Config{ uniqueConfigurations := map[string]*Config{
"proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }},
"no tls": {}, "no tls": {},
"dialer": {Dial: dialer.DialContext}, "dialer": {Dial: dialer.DialContext},
"dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}, "dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
"dialer3": {Dial: dialer.DialContext, DialHolder: &DialHolder{Dial: dialer.DialContext}},
"dialer4": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
"insecure": {TLS: TLSConfig{Insecure: true}}, "insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}}, "cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}}, "cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
@ -128,6 +137,13 @@ func TestTLSConfigKey(t *testing.T) {
GetCert: func() (*tls.Certificate, error) { return nil, nil }, GetCert: func() (*tls.Certificate, error) { return nil, nil },
}, },
}, },
"getCert3": {
TLS: TLSConfig{
KeyData: []byte{1},
GetCert: getCert,
GetCertHolder: getCertHolder,
},
},
"getCert1, key 2": { "getCert1, key 2": {
TLS: TLSConfig{ TLS: TLSConfig{
KeyData: []byte{2}, KeyData: []byte{2},

View File

@ -68,7 +68,11 @@ type Config struct {
WrapTransport WrapperFunc WrapTransport WrapperFunc
// Dial specifies the dial function for creating unencrypted TCP connections. // Dial specifies the dial function for creating unencrypted TCP connections.
// If specified, this transport will be non-cacheable unless DialHolder is also set.
Dial func(ctx context.Context, network, address string) (net.Conn, error) Dial func(ctx context.Context, network, address string) (net.Conn, error)
// DialHolder can be populated to make transport configs cacheable.
// If specified, DialHolder.Dial must be equal to Dial.
DialHolder *DialHolder
// Proxy is the proxy func to be used for all requests made by this // Proxy is the proxy func to be used for all requests made by this
// transport. If Proxy is nil, http.ProxyFromEnvironment is used. If Proxy // transport. If Proxy is nil, http.ProxyFromEnvironment is used. If Proxy
@ -78,6 +82,11 @@ type Config struct {
Proxy func(*http.Request) (*url.URL, error) Proxy func(*http.Request) (*url.URL, error)
} }
// DialHolder is used to make the wrapped function comparable so that it can be used as a map key.
type DialHolder struct {
Dial func(ctx context.Context, network, address string) (net.Conn, error)
}
// ImpersonationConfig has all the available impersonation options // ImpersonationConfig has all the available impersonation options
type ImpersonationConfig struct { type ImpersonationConfig struct {
// UserName matches user.Info.GetName() // UserName matches user.Info.GetName()
@ -143,5 +152,15 @@ type TLSConfig struct {
// To use only http/1.1, set to ["http/1.1"]. // To use only http/1.1, set to ["http/1.1"].
NextProtos []string NextProtos []string
GetCert func() (*tls.Certificate, error) // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field. // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
// If specified, this transport is non-cacheable unless CertHolder is populated.
GetCert func() (*tls.Certificate, error)
// CertHolder can be populated to make transport configs that set GetCert cacheable.
// If set, CertHolder.GetCert must be equal to GetCert.
GetCertHolder *GetCertHolder
}
// GetCertHolder is used to make the wrapped function comparable so that it can be used as a map key.
type GetCertHolder struct {
GetCert func() (*tls.Certificate, error)
} }

View File

@ -24,6 +24,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
"reflect"
"sync" "sync"
"time" "time"
@ -39,6 +40,10 @@ func New(config *Config) (http.RoundTripper, error) {
return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed") return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
} }
if !isValidHolders(config) {
return nil, fmt.Errorf("misconfigured holder for dialer or cert callback")
}
var ( var (
rt http.RoundTripper rt http.RoundTripper
err error err error
@ -56,6 +61,26 @@ func New(config *Config) (http.RoundTripper, error) {
return HTTPWrappersForConfig(config, rt) return HTTPWrappersForConfig(config, rt)
} }
func isValidHolders(config *Config) bool {
if config.TLS.GetCertHolder != nil {
if config.TLS.GetCertHolder.GetCert == nil ||
config.TLS.GetCert == nil ||
reflect.ValueOf(config.TLS.GetCertHolder.GetCert).Pointer() != reflect.ValueOf(config.TLS.GetCert).Pointer() {
return false
}
}
if config.DialHolder != nil {
if config.DialHolder.Dial == nil ||
config.Dial == nil ||
reflect.ValueOf(config.DialHolder.Dial).Pointer() != reflect.ValueOf(config.Dial).Pointer() {
return false
}
}
return true
}
// TLSConfigFor returns a tls.Config that will provide the transport level security defined // TLSConfigFor returns a tls.Config that will provide the transport level security defined
// by the provided Config. Will return nil if no transport level security is requested. // by the provided Config. Will return nil if no transport level security is requested.
func TLSConfigFor(c *Config) (*tls.Config, error) { func TLSConfigFor(c *Config) (*tls.Config, error) {

View File

@ -21,6 +21,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"testing" "testing"
) )
@ -94,6 +95,13 @@ stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
globalGetCert := &GetCertHolder{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
}
globalDial := &DialHolder{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
}
testCases := map[string]struct { testCases := map[string]struct {
Config *Config Config *Config
Err bool Err bool
@ -255,6 +263,144 @@ func TestNew(t *testing.T) {
}, },
}, },
}, },
"nil holders and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: nil,
DialHolder: nil,
},
Err: false,
TLS: false,
TLSCert: false,
TLSErr: false,
Default: true,
Insecure: false,
DefaultRoots: false,
},
"nil holders and non-nil regular get cert": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: nil,
},
Dial: nil,
DialHolder: nil,
},
Err: false,
TLS: true,
TLSCert: true,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
"nil holders and non-nil regular dial": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: nil,
},
Err: false,
TLS: true,
TLSCert: false,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
"non-nil dial holder and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: nil,
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: &GetCertHolder{},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: &GetCertHolder{},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder+internal and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: &DialHolder{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
},
},
Err: true,
},
"non-nil cert holder+internal and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: &GetCertHolder{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil holders+internal and non-nil regular with correct address": {
Config: &Config{
TLS: TLSConfig{
GetCert: globalGetCert.GetCert,
GetCertHolder: globalGetCert,
},
Dial: globalDial.Dial,
DialHolder: globalDial,
},
Err: false,
TLS: true,
TLSCert: true,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
} }
for k, testCase := range testCases { for k, testCase := range testCases {
t.Run(k, func(t *testing.T) { t.Run(k, func(t *testing.T) {