exec auth: support TLS config caching

This change updates the transport.Config .Dial and .TLS.GetCert fields
to use a struct wrapper.  This indirection via a pointer allows the
functions to be compared and thus makes them valid to use as map keys.
This change is then leveraged by the existing global exec auth and TLS
config caches to return the same authenticator and TLS config even when
distinct but identical rest configs were used to create distinct
clientsets.

Signed-off-by: Monis Khan <mok@microsoft.com>

Kubernetes-commit: 831d95b6a021c2767effe85e461309cb6a0fdcec
This commit is contained in:
Monis Khan 2022-08-24 16:04:19 +00:00 committed by Kubernetes Publisher
parent 2698e8276e
commit 6a008ec216
7 changed files with 354 additions and 15 deletions

View File

@ -199,14 +199,18 @@ func newAuthenticator(c *cache, isTerminalFunc func(int) bool, config *api.ExecC
now: time.Now,
environ: os.Environ,
defaultDialer: defaultDialer,
connTracker: connTracker,
connTracker: connTracker,
}
for _, env := range config.Env {
a.env = append(a.env, env.Name+"="+env.Value)
}
// these functions are made comparable and stored in the cache so that repeated clientset
// construction with the same rest.Config results in a single TLS cache and Authenticator
a.getCert = &transport.GetCertHolder{GetCert: a.cert}
a.dial = &transport.DialHolder{Dial: defaultDialer.DialContext}
return c.put(key, a), nil
}
@ -261,8 +265,6 @@ type Authenticator struct {
now func() time.Time
environ func() []string
// defaultDialer is used for clients which don't specify a custom dialer
defaultDialer *connrotation.Dialer
// connTracker tracks all connections opened that we need to close when rotating a client certificate
connTracker *connrotation.ConnectionTracker
@ -273,6 +275,12 @@ type Authenticator struct {
mu sync.Mutex
cachedCreds *credentials
exp time.Time
// getCert makes Authenticator.cert comparable to support TLS config caching
getCert *transport.GetCertHolder
// dial is used for clients which do not specify a custom dialer
// it is comparable to support TLS config caching
dial *transport.DialHolder
}
type credentials struct {
@ -300,18 +308,20 @@ func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error {
if c.HasCertCallback() {
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")
}
c.TLS.GetCert = a.cert
c.TLS.GetCert = a.getCert.GetCert
c.TLS.GetCertHolder = a.getCert // comparable for TLS config caching
var d *connrotation.Dialer
if c.Dial != nil {
// if c has a custom dialer, we have to wrap it
d = connrotation.NewDialerWithTracker(c.Dial, a.connTracker)
// TLS config caching is not supported for this config
d := connrotation.NewDialerWithTracker(c.Dial, a.connTracker)
c.Dial = d.DialContext
c.DialHolder = nil
} else {
d = a.defaultDialer
c.Dial = a.dial.Dial
c.DialHolder = a.dial // comparable for TLS config caching
}
c.Dial = d.DialContext
return nil
}

View File

@ -0,0 +1,106 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package exec_test // separate package to prevent circular import
import (
"context"
"testing"
"time"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
utilnet "k8s.io/apimachinery/pkg/util/net"
clientset "k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
)
// TestExecTLSCache asserts the semantics of the TLS cache when exec auth is used.
//
// In particular, when:
// - multiple identical rest configs exist as distinct objects, and
// - these rest configs use exec auth, and
// - these rest configs are used to create distinct clientsets, then
//
// the underlying TLS config is shared between those clientsets.
func TestExecTLSCache(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
t.Cleanup(cancel)
config1 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client1 := clientset.NewForConfigOrDie(config1)
config2 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client2 := clientset.NewForConfigOrDie(config2)
config3 := &rest.Config{
Host: "https://localhost",
ExecProvider: &clientcmdapi.ExecConfig{
Command: "./testdata/test-plugin.sh",
Args: []string{"make this exec auth different"},
APIVersion: "client.authentication.k8s.io/v1",
InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode,
},
}
client3 := clientset.NewForConfigOrDie(config3)
_, _ = client1.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
_, _ = client2.CoreV1().Namespaces().List(ctx, metav1.ListOptions{})
_, _ = client3.CoreV1().PersistentVolumes().List(ctx, metav1.ListOptions{})
rt1 := client1.RESTClient().(*rest.RESTClient).Client.Transport
rt2 := client2.RESTClient().(*rest.RESTClient).Client.Transport
rt3 := client3.RESTClient().(*rest.RESTClient).Client.Transport
tlsConfig1, err := utilnet.TLSClientConfig(rt1)
if err != nil {
t.Fatal(err)
}
tlsConfig2, err := utilnet.TLSClientConfig(rt2)
if err != nil {
t.Fatal(err)
}
tlsConfig3, err := utilnet.TLSClientConfig(rt3)
if err != nil {
t.Fatal(err)
}
if tlsConfig1 == nil || tlsConfig2 == nil || tlsConfig3 == nil {
t.Fatal("expected non-nil TLS configs")
}
if tlsConfig1 != tlsConfig2 {
t.Fatal("expected the same TLS config for matching exec config via rest config")
}
if tlsConfig1 == tlsConfig3 {
t.Fatal("expected different TLS config for non-matching exec config via rest config")
}
}

View File

@ -17,6 +17,7 @@ limitations under the License.
package transport
import (
"context"
"fmt"
"net"
"net/http"
@ -55,6 +56,9 @@ type tlsCacheKey struct {
serverName string
nextProtos string
disableCompression bool
// these functions are wrapped to allow them to be used as map keys
getCert *GetCertHolder
dial *DialHolder
}
func (t tlsCacheKey) String() string {
@ -62,7 +66,8 @@ func (t tlsCacheKey) String() string {
if len(t.keyData) > 0 {
keyText = "<redacted>"
}
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression)
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p",
t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial)
}
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
@ -92,8 +97,10 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
return http.DefaultTransport, nil
}
dial := config.Dial
if dial == nil {
var dial func(ctx context.Context, network, address string) (net.Conn, error)
if config.Dial != nil {
dial = config.Dial
} else {
dial = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
@ -138,10 +145,18 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
return tlsCacheKey{}, false, err
}
if c.TLS.GetCert != nil || c.Dial != nil || c.Proxy != nil {
if c.Proxy != nil {
// cannot determine equality for functions
return tlsCacheKey{}, false, nil
}
if c.Dial != nil && c.DialHolder == nil {
// cannot determine equality for dial function that doesn't have non-nil DialHolder set as well
return tlsCacheKey{}, false, nil
}
if c.TLS.GetCert != nil && c.TLS.GetCertHolder == nil {
// cannot determine equality for getCert function that doesn't have non-nil GetCertHolder set as well
return tlsCacheKey{}, false, nil
}
k := tlsCacheKey{
insecure: c.TLS.Insecure,
@ -149,6 +164,8 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
serverName: c.TLS.ServerName,
nextProtos: strings.Join(c.TLS.NextProtos, ","),
disableCompression: c.DisableCompression,
getCert: c.TLS.GetCertHolder,
dial: c.DialHolder,
}
if c.TLS.ReloadTLSFiles {

View File

@ -21,6 +21,7 @@ import (
"crypto/tls"
"net"
"net/http"
"net/url"
"testing"
)
@ -58,16 +59,24 @@ func TestTLSConfigKey(t *testing.T) {
t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue
}
if keyA != (tlsCacheKey{}) {
t.Errorf("Expected empty cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue
}
}
}
// Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{}
getCert := func() (*tls.Certificate, error) { return nil, nil }
getCertHolder := &GetCertHolder{GetCert: getCert}
uniqueConfigurations := map[string]*Config{
"proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }},
"no tls": {},
"dialer": {Dial: dialer.DialContext},
"dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
"dialer3": {Dial: dialer.DialContext, DialHolder: &DialHolder{Dial: dialer.DialContext}},
"dialer4": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
"insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
@ -128,6 +137,13 @@ func TestTLSConfigKey(t *testing.T) {
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
"getCert3": {
TLS: TLSConfig{
KeyData: []byte{1},
GetCert: getCert,
GetCertHolder: getCertHolder,
},
},
"getCert1, key 2": {
TLS: TLSConfig{
KeyData: []byte{2},

View File

@ -68,7 +68,11 @@ type Config struct {
WrapTransport WrapperFunc
// Dial specifies the dial function for creating unencrypted TCP connections.
// If specified, this transport will be non-cacheable unless DialHolder is also set.
Dial func(ctx context.Context, network, address string) (net.Conn, error)
// DialHolder can be populated to make transport configs cacheable.
// If specified, DialHolder.Dial must be equal to Dial.
DialHolder *DialHolder
// Proxy is the proxy func to be used for all requests made by this
// transport. If Proxy is nil, http.ProxyFromEnvironment is used. If Proxy
@ -78,6 +82,11 @@ type Config struct {
Proxy func(*http.Request) (*url.URL, error)
}
// DialHolder is used to make the wrapped function comparable so that it can be used as a map key.
type DialHolder struct {
Dial func(ctx context.Context, network, address string) (net.Conn, error)
}
// ImpersonationConfig has all the available impersonation options
type ImpersonationConfig struct {
// UserName matches user.Info.GetName()
@ -143,5 +152,15 @@ type TLSConfig struct {
// To use only http/1.1, set to ["http/1.1"].
NextProtos []string
GetCert func() (*tls.Certificate, error) // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
// Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
// If specified, this transport is non-cacheable unless CertHolder is populated.
GetCert func() (*tls.Certificate, error)
// CertHolder can be populated to make transport configs that set GetCert cacheable.
// If set, CertHolder.GetCert must be equal to GetCert.
GetCertHolder *GetCertHolder
}
// GetCertHolder is used to make the wrapped function comparable so that it can be used as a map key.
type GetCertHolder struct {
GetCert func() (*tls.Certificate, error)
}

View File

@ -24,6 +24,7 @@ import (
"fmt"
"net/http"
"os"
"reflect"
"sync"
"time"
@ -39,6 +40,10 @@ func New(config *Config) (http.RoundTripper, error) {
return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
}
if !isValidHolders(config) {
return nil, fmt.Errorf("misconfigured holder for dialer or cert callback")
}
var (
rt http.RoundTripper
err error
@ -56,6 +61,26 @@ func New(config *Config) (http.RoundTripper, error) {
return HTTPWrappersForConfig(config, rt)
}
func isValidHolders(config *Config) bool {
if config.TLS.GetCertHolder != nil {
if config.TLS.GetCertHolder.GetCert == nil ||
config.TLS.GetCert == nil ||
reflect.ValueOf(config.TLS.GetCertHolder.GetCert).Pointer() != reflect.ValueOf(config.TLS.GetCert).Pointer() {
return false
}
}
if config.DialHolder != nil {
if config.DialHolder.Dial == nil ||
config.Dial == nil ||
reflect.ValueOf(config.DialHolder.Dial).Pointer() != reflect.ValueOf(config.Dial).Pointer() {
return false
}
}
return true
}
// TLSConfigFor returns a tls.Config that will provide the transport level security defined
// by the provided Config. Will return nil if no transport level security is requested.
func TLSConfigFor(c *Config) (*tls.Config, error) {

View File

@ -21,6 +21,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"testing"
)
@ -94,6 +95,13 @@ stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa
)
func TestNew(t *testing.T) {
globalGetCert := &GetCertHolder{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
}
globalDial := &DialHolder{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
}
testCases := map[string]struct {
Config *Config
Err bool
@ -255,6 +263,144 @@ func TestNew(t *testing.T) {
},
},
},
"nil holders and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: nil,
DialHolder: nil,
},
Err: false,
TLS: false,
TLSCert: false,
TLSErr: false,
Default: true,
Insecure: false,
DefaultRoots: false,
},
"nil holders and non-nil regular get cert": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: nil,
},
Dial: nil,
DialHolder: nil,
},
Err: false,
TLS: true,
TLSCert: true,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
"nil holders and non-nil regular dial": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: nil,
},
Err: false,
TLS: true,
TLSCert: false,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
"non-nil dial holder and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: nil,
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: &GetCertHolder{},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: &GetCertHolder{},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder+internal and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: &DialHolder{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
},
},
Err: true,
},
"non-nil cert holder+internal and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: &GetCertHolder{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil holders+internal and non-nil regular with correct address": {
Config: &Config{
TLS: TLSConfig{
GetCert: globalGetCert.GetCert,
GetCertHolder: globalGetCert,
},
Dial: globalDial.Dial,
DialHolder: globalDial,
},
Err: false,
TLS: true,
TLSCert: true,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
}
for k, testCase := range testCases {
t.Run(k, func(t *testing.T) {