client-go/transport: drop Dial and GetCert fields in favor of Holders

Signed-off-by: Monis Khan <mok@microsoft.com>

Kubernetes-commit: 3313a70d5bcc40a39f99f482c18effc9de6072ba
This commit is contained in:
Monis Khan 2022-09-09 08:06:01 -04:00 committed by Kubernetes Publisher
parent eecd3e52a3
commit 5dab9a0b84
8 changed files with 101 additions and 166 deletions

View File

@ -308,17 +308,18 @@ func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error {
if c.HasCertCallback() {
return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set")
}
c.TLS.GetCert = a.getCert.GetCert
c.TLS.GetCertHolder = a.getCert // comparable for TLS config caching
if c.Dial != nil {
if c.DialHolder != nil {
if c.DialHolder.Dial == nil {
return errors.New("invalid transport.Config.DialHolder: wrapped Dial function is nil")
}
// if c has a custom dialer, we have to wrap it
// TLS config caching is not supported for this config
d := connrotation.NewDialerWithTracker(c.Dial, a.connTracker)
c.Dial = d.DialContext
c.DialHolder = nil
d := connrotation.NewDialerWithTracker(c.DialHolder.Dial, a.connTracker)
c.DialHolder = &transport.DialHolder{Dial: d.DialContext}
} else {
c.Dial = a.dial.Dial
c.DialHolder = a.dial // comparable for TLS config caching
}

View File

@ -1005,7 +1005,7 @@ func TestAuthorizationHeaderPresentCancelsExecAction(t *testing.T) {
cert := func() (*tls.Certificate, error) {
return nil, nil
}
tc := &transport.Config{TLS: transport.TLSConfig{Insecure: true, GetCert: cert}}
tc := &transport.Config{TLS: transport.TLSConfig{Insecure: true, GetCertHolder: &transport.GetCertHolder{GetCert: cert}}}
test.setTransportConfig(tc)
if err := a.UpdateTransportConfig(tc); err != nil {

View File

@ -108,10 +108,13 @@ func (c *Config) TransportConfig() (*transport.Config, error) {
Groups: c.Impersonate.Groups,
Extra: c.Impersonate.Extra,
},
Dial: c.Dial,
Proxy: c.Proxy,
}
if c.Dial != nil {
conf.DialHolder = &transport.DialHolder{Dial: c.Dial}
}
if c.ExecProvider != nil && c.AuthProvider != nil {
return nil, errors.New("execProvider and authProvider cannot be used in combination")
}

View File

@ -93,13 +93,13 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
return nil, err
}
// The options didn't require a custom TLS config
if tlsConfig == nil && config.Dial == nil && config.Proxy == nil {
if tlsConfig == nil && config.DialHolder == nil && config.Proxy == nil {
return http.DefaultTransport, nil
}
var dial func(ctx context.Context, network, address string) (net.Conn, error)
if config.Dial != nil {
dial = config.Dial
if config.DialHolder != nil {
dial = config.DialHolder.Dial
} else {
dial = (&net.Dialer{
Timeout: 30 * time.Second,
@ -149,14 +149,6 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
// cannot determine equality for functions
return tlsCacheKey{}, false, nil
}
if c.Dial != nil && c.DialHolder == nil {
// cannot determine equality for dial function that doesn't have non-nil DialHolder set as well
return tlsCacheKey{}, false, nil
}
if c.TLS.GetCert != nil && c.TLS.GetCertHolder == nil {
// cannot determine equality for getCert function that doesn't have non-nil GetCertHolder set as well
return tlsCacheKey{}, false, nil
}
k := tlsCacheKey{
insecure: c.TLS.Insecure,

View File

@ -22,6 +22,7 @@ import (
"net"
"net/http"
"net/url"
"reflect"
"testing"
)
@ -68,15 +69,12 @@ func TestTLSConfigKey(t *testing.T) {
// Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{}
getCert := func() (*tls.Certificate, error) { return nil, nil }
getCertHolder := &GetCertHolder{GetCert: getCert}
getCert := &GetCertHolder{GetCert: func() (*tls.Certificate, error) { return nil, nil }}
uniqueConfigurations := map[string]*Config{
"proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }},
"no tls": {},
"dialer": {Dial: dialer.DialContext},
"dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
"dialer3": {Dial: dialer.DialContext, DialHolder: &DialHolder{Dial: dialer.DialContext}},
"dialer4": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
"dialer": {DialHolder: &DialHolder{Dial: dialer.DialContext}},
"dialer2": {DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
"insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
@ -127,27 +125,20 @@ func TestTLSConfigKey(t *testing.T) {
},
"getCert1": {
TLS: TLSConfig{
KeyData: []byte{1},
GetCert: getCert,
KeyData: []byte{1},
GetCertHolder: getCert,
},
},
"getCert2": {
TLS: TLSConfig{
KeyData: []byte{1},
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
"getCert3": {
TLS: TLSConfig{
KeyData: []byte{1},
GetCert: getCert,
GetCertHolder: getCertHolder,
GetCertHolder: &GetCertHolder{GetCert: func() (*tls.Certificate, error) { return nil, nil }},
},
},
"getCert1, key 2": {
TLS: TLSConfig{
KeyData: []byte{2},
GetCert: getCert,
KeyData: []byte{2},
GetCertHolder: getCert,
},
},
"http2, http1.1": {TLS: TLSConfig{NextProtos: []string{"h2", "http/1.1"}}},
@ -166,6 +157,17 @@ func TestTLSConfigKey(t *testing.T) {
continue
}
shouldCacheA := valueA.Proxy == nil
if shouldCacheA != canCacheA {
t.Errorf("Unexpected canCache=false for " + nameA)
}
configIsNotEmpty := !reflect.DeepEqual(*valueA, Config{})
if keyA == (tlsCacheKey{}) && shouldCacheA && configIsNotEmpty {
t.Errorf("Expected non-empty cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue
}
// Make sure we get the same key on the same config
if nameA == nameB {
if keyA != keyB {

View File

@ -67,11 +67,8 @@ type Config struct {
// instead of setting this value directly.
WrapTransport WrapperFunc
// Dial specifies the dial function for creating unencrypted TCP connections.
// If specified, this transport will be non-cacheable unless DialHolder is also set.
Dial func(ctx context.Context, network, address string) (net.Conn, error)
// DialHolder can be populated to make transport configs cacheable.
// If specified, DialHolder.Dial must be equal to Dial.
// DialHolder specifies the dial function for creating unencrypted TCP connections.
// This struct indirection is used to make transport configs cacheable.
DialHolder *DialHolder
// Proxy is the proxy func to be used for all requests made by this
@ -121,7 +118,7 @@ func (c *Config) HasCertAuth() bool {
// HasCertCallback returns whether the configuration has certificate callback or not.
func (c *Config) HasCertCallback() bool {
return c.TLS.GetCert != nil
return c.TLS.GetCertHolder != nil
}
// Wrap adds a transport middleware function that will give the caller
@ -153,10 +150,7 @@ type TLSConfig struct {
NextProtos []string
// Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field.
// If specified, this transport is non-cacheable unless CertHolder is populated.
GetCert func() (*tls.Certificate, error)
// CertHolder can be populated to make transport configs that set GetCert cacheable.
// If set, CertHolder.GetCert must be equal to GetCert.
// This struct indirection is used to make transport configs cacheable.
GetCertHolder *GetCertHolder
}

View File

@ -24,7 +24,6 @@ import (
"fmt"
"net/http"
"os"
"reflect"
"sync"
"time"
@ -62,20 +61,12 @@ func New(config *Config) (http.RoundTripper, error) {
}
func isValidHolders(config *Config) bool {
if config.TLS.GetCertHolder != nil {
if config.TLS.GetCertHolder.GetCert == nil ||
config.TLS.GetCert == nil ||
reflect.ValueOf(config.TLS.GetCertHolder.GetCert).Pointer() != reflect.ValueOf(config.TLS.GetCert).Pointer() {
return false
}
if config.TLS.GetCertHolder != nil && config.TLS.GetCertHolder.GetCert == nil {
return false
}
if config.DialHolder != nil {
if config.DialHolder.Dial == nil ||
config.Dial == nil ||
reflect.ValueOf(config.DialHolder.Dial).Pointer() != reflect.ValueOf(config.Dial).Pointer() {
return false
}
if config.DialHolder != nil && config.DialHolder.Dial == nil {
return false
}
return true
@ -141,7 +132,7 @@ func TLSConfigFor(c *Config) (*tls.Config, error) {
return dynamicCertLoader()
}
if c.HasCertCallback() {
cert, err := c.TLS.GetCert()
cert, err := c.TLS.GetCertHolder.GetCert()
if err != nil {
return nil, err
}

View File

@ -217,9 +217,11 @@ func TestNew(t *testing.T) {
Config: &Config{
TLS: TLSConfig{
CAData: []byte(rootCACert),
GetCert: func() (*tls.Certificate, error) {
crt, err := tls.X509KeyPair([]byte(certData), []byte(keyData))
return &crt, err
GetCertHolder: &GetCertHolder{
GetCert: func() (*tls.Certificate, error) {
crt, err := tls.X509KeyPair([]byte(certData), []byte(keyData))
return &crt, err
},
},
},
},
@ -231,8 +233,10 @@ func TestNew(t *testing.T) {
Config: &Config{
TLS: TLSConfig{
CAData: []byte(rootCACert),
GetCert: func() (*tls.Certificate, error) {
return nil, errors.New("GetCert failure")
GetCertHolder: &GetCertHolder{
GetCert: func() (*tls.Certificate, error) {
return nil, errors.New("GetCert failure")
},
},
},
},
@ -243,8 +247,10 @@ func TestNew(t *testing.T) {
Config: &Config{
TLS: TLSConfig{
CAData: []byte(rootCACert),
GetCert: func() (*tls.Certificate, error) {
return nil, nil
GetCertHolder: &GetCertHolder{
GetCert: func() (*tls.Certificate, error) {
return nil, nil
},
},
CertData: []byte(certData),
KeyData: []byte(keyData),
@ -257,19 +263,19 @@ func TestNew(t *testing.T) {
Config: &Config{
TLS: TLSConfig{
CAData: []byte(rootCACert),
GetCert: func() (*tls.Certificate, error) {
return nil, nil
GetCertHolder: &GetCertHolder{
GetCert: func() (*tls.Certificate, error) {
return nil, nil
},
},
},
},
},
"nil holders and nil regular": {
"nil holders": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: nil,
DialHolder: nil,
},
Err: false,
@ -280,13 +286,48 @@ func TestNew(t *testing.T) {
Insecure: false,
DefaultRoots: false,
},
"nil holders and non-nil regular get cert": {
"non-nil dial holder and nil internal": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: nil,
},
Dial: nil,
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and nil internal": {
Config: &Config{
TLS: TLSConfig{
GetCertHolder: &GetCertHolder{},
},
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder+internal": {
Config: &Config{
TLS: TLSConfig{
GetCertHolder: nil,
},
DialHolder: &DialHolder{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
},
},
Err: false,
TLS: true,
TLSCert: false,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
"non-nil cert holder+internal": {
Config: &Config{
TLS: TLSConfig{
GetCertHolder: &GetCertHolder{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
DialHolder: nil,
},
Err: false,
@ -297,100 +338,11 @@ func TestNew(t *testing.T) {
Insecure: false,
DefaultRoots: true,
},
"nil holders and non-nil regular dial": {
"non-nil holders+internal with global address": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: nil,
},
Err: false,
TLS: true,
TLSCert: false,
TLSErr: false,
Default: false,
Insecure: false,
DefaultRoots: true,
},
"non-nil dial holder and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: nil,
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: &GetCertHolder{},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: &DialHolder{},
},
Err: true,
},
"non-nil cert holder and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: &GetCertHolder{},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil dial holder+internal and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: nil,
GetCertHolder: nil,
},
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
DialHolder: &DialHolder{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil },
},
},
Err: true,
},
"non-nil cert holder+internal and non-nil regular": {
Config: &Config{
TLS: TLSConfig{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
GetCertHolder: &GetCertHolder{
GetCert: func() (*tls.Certificate, error) { return nil, nil },
},
},
Dial: nil,
DialHolder: nil,
},
Err: true,
},
"non-nil holders+internal and non-nil regular with correct address": {
Config: &Config{
TLS: TLSConfig{
GetCert: globalGetCert.GetCert,
GetCertHolder: globalGetCert,
},
Dial: globalDial.Dial,
DialHolder: globalDial,
},
Err: false,