diff --git a/plugin/pkg/client/auth/exec/exec.go b/plugin/pkg/client/auth/exec/exec.go index 73876f68..5331b237 100644 --- a/plugin/pkg/client/auth/exec/exec.go +++ b/plugin/pkg/client/auth/exec/exec.go @@ -308,17 +308,18 @@ func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error { if c.HasCertCallback() { return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set") } - c.TLS.GetCert = a.getCert.GetCert c.TLS.GetCertHolder = a.getCert // comparable for TLS config caching - if c.Dial != nil { + if c.DialHolder != nil { + if c.DialHolder.Dial == nil { + return errors.New("invalid transport.Config.DialHolder: wrapped Dial function is nil") + } + // if c has a custom dialer, we have to wrap it // TLS config caching is not supported for this config - d := connrotation.NewDialerWithTracker(c.Dial, a.connTracker) - c.Dial = d.DialContext - c.DialHolder = nil + d := connrotation.NewDialerWithTracker(c.DialHolder.Dial, a.connTracker) + c.DialHolder = &transport.DialHolder{Dial: d.DialContext} } else { - c.Dial = a.dial.Dial c.DialHolder = a.dial // comparable for TLS config caching } diff --git a/plugin/pkg/client/auth/exec/exec_test.go b/plugin/pkg/client/auth/exec/exec_test.go index b56dd23b..ea6ef781 100644 --- a/plugin/pkg/client/auth/exec/exec_test.go +++ b/plugin/pkg/client/auth/exec/exec_test.go @@ -1005,7 +1005,7 @@ func TestAuthorizationHeaderPresentCancelsExecAction(t *testing.T) { cert := func() (*tls.Certificate, error) { return nil, nil } - tc := &transport.Config{TLS: transport.TLSConfig{Insecure: true, GetCert: cert}} + tc := &transport.Config{TLS: transport.TLSConfig{Insecure: true, GetCertHolder: &transport.GetCertHolder{GetCert: cert}}} test.setTransportConfig(tc) if err := a.UpdateTransportConfig(tc); err != nil { diff --git a/rest/transport.go b/rest/transport.go index 7c38c6d9..53f986cb 100644 --- a/rest/transport.go +++ b/rest/transport.go @@ -108,10 +108,13 @@ func (c *Config) TransportConfig() (*transport.Config, error) { Groups: c.Impersonate.Groups, Extra: c.Impersonate.Extra, }, - Dial: c.Dial, Proxy: c.Proxy, } + if c.Dial != nil { + conf.DialHolder = &transport.DialHolder{Dial: c.Dial} + } + if c.ExecProvider != nil && c.AuthProvider != nil { return nil, errors.New("execProvider and authProvider cannot be used in combination") } diff --git a/transport/cache.go b/transport/cache.go index b4f8dab0..9d2889d1 100644 --- a/transport/cache.go +++ b/transport/cache.go @@ -93,13 +93,13 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { return nil, err } // The options didn't require a custom TLS config - if tlsConfig == nil && config.Dial == nil && config.Proxy == nil { + if tlsConfig == nil && config.DialHolder == nil && config.Proxy == nil { return http.DefaultTransport, nil } var dial func(ctx context.Context, network, address string) (net.Conn, error) - if config.Dial != nil { - dial = config.Dial + if config.DialHolder != nil { + dial = config.DialHolder.Dial } else { dial = (&net.Dialer{ Timeout: 30 * time.Second, @@ -149,14 +149,6 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) { // cannot determine equality for functions return tlsCacheKey{}, false, nil } - if c.Dial != nil && c.DialHolder == nil { - // cannot determine equality for dial function that doesn't have non-nil DialHolder set as well - return tlsCacheKey{}, false, nil - } - if c.TLS.GetCert != nil && c.TLS.GetCertHolder == nil { - // cannot determine equality for getCert function that doesn't have non-nil GetCertHolder set as well - return tlsCacheKey{}, false, nil - } k := tlsCacheKey{ insecure: c.TLS.Insecure, diff --git a/transport/cache_test.go b/transport/cache_test.go index 87d070bb..f2e455cc 100644 --- a/transport/cache_test.go +++ b/transport/cache_test.go @@ -22,6 +22,7 @@ import ( "net" "net/http" "net/url" + "reflect" "testing" ) @@ -68,15 +69,12 @@ func TestTLSConfigKey(t *testing.T) { // Make sure config fields that affect the tls config affect the cache key dialer := net.Dialer{} - getCert := func() (*tls.Certificate, error) { return nil, nil } - getCertHolder := &GetCertHolder{GetCert: getCert} + getCert := &GetCertHolder{GetCert: func() (*tls.Certificate, error) { return nil, nil }} uniqueConfigurations := map[string]*Config{ "proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }}, "no tls": {}, - "dialer": {Dial: dialer.DialContext}, - "dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}, - "dialer3": {Dial: dialer.DialContext, DialHolder: &DialHolder{Dial: dialer.DialContext}}, - "dialer4": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}}, + "dialer": {DialHolder: &DialHolder{Dial: dialer.DialContext}}, + "dialer2": {DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}}, "insecure": {TLS: TLSConfig{Insecure: true}}, "cadata 1": {TLS: TLSConfig{CAData: []byte{1}}}, "cadata 2": {TLS: TLSConfig{CAData: []byte{2}}}, @@ -127,27 +125,20 @@ func TestTLSConfigKey(t *testing.T) { }, "getCert1": { TLS: TLSConfig{ - KeyData: []byte{1}, - GetCert: getCert, + KeyData: []byte{1}, + GetCertHolder: getCert, }, }, "getCert2": { - TLS: TLSConfig{ - KeyData: []byte{1}, - GetCert: func() (*tls.Certificate, error) { return nil, nil }, - }, - }, - "getCert3": { TLS: TLSConfig{ KeyData: []byte{1}, - GetCert: getCert, - GetCertHolder: getCertHolder, + GetCertHolder: &GetCertHolder{GetCert: func() (*tls.Certificate, error) { return nil, nil }}, }, }, "getCert1, key 2": { TLS: TLSConfig{ - KeyData: []byte{2}, - GetCert: getCert, + KeyData: []byte{2}, + GetCertHolder: getCert, }, }, "http2, http1.1": {TLS: TLSConfig{NextProtos: []string{"h2", "http/1.1"}}}, @@ -166,6 +157,17 @@ func TestTLSConfigKey(t *testing.T) { continue } + shouldCacheA := valueA.Proxy == nil + if shouldCacheA != canCacheA { + t.Errorf("Unexpected canCache=false for " + nameA) + } + + configIsNotEmpty := !reflect.DeepEqual(*valueA, Config{}) + if keyA == (tlsCacheKey{}) && shouldCacheA && configIsNotEmpty { + t.Errorf("Expected non-empty cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) + continue + } + // Make sure we get the same key on the same config if nameA == nameB { if keyA != keyB { diff --git a/transport/config.go b/transport/config.go index fd853c0b..d8a3d64b 100644 --- a/transport/config.go +++ b/transport/config.go @@ -67,11 +67,8 @@ type Config struct { // instead of setting this value directly. WrapTransport WrapperFunc - // Dial specifies the dial function for creating unencrypted TCP connections. - // If specified, this transport will be non-cacheable unless DialHolder is also set. - Dial func(ctx context.Context, network, address string) (net.Conn, error) - // DialHolder can be populated to make transport configs cacheable. - // If specified, DialHolder.Dial must be equal to Dial. + // DialHolder specifies the dial function for creating unencrypted TCP connections. + // This struct indirection is used to make transport configs cacheable. DialHolder *DialHolder // Proxy is the proxy func to be used for all requests made by this @@ -121,7 +118,7 @@ func (c *Config) HasCertAuth() bool { // HasCertCallback returns whether the configuration has certificate callback or not. func (c *Config) HasCertCallback() bool { - return c.TLS.GetCert != nil + return c.TLS.GetCertHolder != nil } // Wrap adds a transport middleware function that will give the caller @@ -153,10 +150,7 @@ type TLSConfig struct { NextProtos []string // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field. - // If specified, this transport is non-cacheable unless CertHolder is populated. - GetCert func() (*tls.Certificate, error) - // CertHolder can be populated to make transport configs that set GetCert cacheable. - // If set, CertHolder.GetCert must be equal to GetCert. + // This struct indirection is used to make transport configs cacheable. GetCertHolder *GetCertHolder } diff --git a/transport/transport.go b/transport/transport.go index 1485548a..78060719 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -24,7 +24,6 @@ import ( "fmt" "net/http" "os" - "reflect" "sync" "time" @@ -62,20 +61,12 @@ func New(config *Config) (http.RoundTripper, error) { } func isValidHolders(config *Config) bool { - if config.TLS.GetCertHolder != nil { - if config.TLS.GetCertHolder.GetCert == nil || - config.TLS.GetCert == nil || - reflect.ValueOf(config.TLS.GetCertHolder.GetCert).Pointer() != reflect.ValueOf(config.TLS.GetCert).Pointer() { - return false - } + if config.TLS.GetCertHolder != nil && config.TLS.GetCertHolder.GetCert == nil { + return false } - if config.DialHolder != nil { - if config.DialHolder.Dial == nil || - config.Dial == nil || - reflect.ValueOf(config.DialHolder.Dial).Pointer() != reflect.ValueOf(config.Dial).Pointer() { - return false - } + if config.DialHolder != nil && config.DialHolder.Dial == nil { + return false } return true @@ -141,7 +132,7 @@ func TLSConfigFor(c *Config) (*tls.Config, error) { return dynamicCertLoader() } if c.HasCertCallback() { - cert, err := c.TLS.GetCert() + cert, err := c.TLS.GetCertHolder.GetCert() if err != nil { return nil, err } diff --git a/transport/transport_test.go b/transport/transport_test.go index e0fd2679..18804422 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -217,9 +217,11 @@ func TestNew(t *testing.T) { Config: &Config{ TLS: TLSConfig{ CAData: []byte(rootCACert), - GetCert: func() (*tls.Certificate, error) { - crt, err := tls.X509KeyPair([]byte(certData), []byte(keyData)) - return &crt, err + GetCertHolder: &GetCertHolder{ + GetCert: func() (*tls.Certificate, error) { + crt, err := tls.X509KeyPair([]byte(certData), []byte(keyData)) + return &crt, err + }, }, }, }, @@ -231,8 +233,10 @@ func TestNew(t *testing.T) { Config: &Config{ TLS: TLSConfig{ CAData: []byte(rootCACert), - GetCert: func() (*tls.Certificate, error) { - return nil, errors.New("GetCert failure") + GetCertHolder: &GetCertHolder{ + GetCert: func() (*tls.Certificate, error) { + return nil, errors.New("GetCert failure") + }, }, }, }, @@ -243,8 +247,10 @@ func TestNew(t *testing.T) { Config: &Config{ TLS: TLSConfig{ CAData: []byte(rootCACert), - GetCert: func() (*tls.Certificate, error) { - return nil, nil + GetCertHolder: &GetCertHolder{ + GetCert: func() (*tls.Certificate, error) { + return nil, nil + }, }, CertData: []byte(certData), KeyData: []byte(keyData), @@ -257,19 +263,19 @@ func TestNew(t *testing.T) { Config: &Config{ TLS: TLSConfig{ CAData: []byte(rootCACert), - GetCert: func() (*tls.Certificate, error) { - return nil, nil + GetCertHolder: &GetCertHolder{ + GetCert: func() (*tls.Certificate, error) { + return nil, nil + }, }, }, }, }, - "nil holders and nil regular": { + "nil holders": { Config: &Config{ TLS: TLSConfig{ - GetCert: nil, GetCertHolder: nil, }, - Dial: nil, DialHolder: nil, }, Err: false, @@ -280,13 +286,48 @@ func TestNew(t *testing.T) { Insecure: false, DefaultRoots: false, }, - "nil holders and non-nil regular get cert": { + "non-nil dial holder and nil internal": { Config: &Config{ TLS: TLSConfig{ - GetCert: func() (*tls.Certificate, error) { return nil, nil }, GetCertHolder: nil, }, - Dial: nil, + DialHolder: &DialHolder{}, + }, + Err: true, + }, + "non-nil cert holder and nil internal": { + Config: &Config{ + TLS: TLSConfig{ + GetCertHolder: &GetCertHolder{}, + }, + DialHolder: nil, + }, + Err: true, + }, + "non-nil dial holder+internal": { + Config: &Config{ + TLS: TLSConfig{ + GetCertHolder: nil, + }, + DialHolder: &DialHolder{ + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + }, + }, + Err: false, + TLS: true, + TLSCert: false, + TLSErr: false, + Default: false, + Insecure: false, + DefaultRoots: true, + }, + "non-nil cert holder+internal": { + Config: &Config{ + TLS: TLSConfig{ + GetCertHolder: &GetCertHolder{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + }, + }, DialHolder: nil, }, Err: false, @@ -297,100 +338,11 @@ func TestNew(t *testing.T) { Insecure: false, DefaultRoots: true, }, - "nil holders and non-nil regular dial": { + "non-nil holders+internal with global address": { Config: &Config{ TLS: TLSConfig{ - GetCert: nil, - GetCertHolder: nil, - }, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, - DialHolder: nil, - }, - Err: false, - TLS: true, - TLSCert: false, - TLSErr: false, - Default: false, - Insecure: false, - DefaultRoots: true, - }, - "non-nil dial holder and nil regular": { - Config: &Config{ - TLS: TLSConfig{ - GetCert: nil, - GetCertHolder: nil, - }, - Dial: nil, - DialHolder: &DialHolder{}, - }, - Err: true, - }, - "non-nil cert holder and nil regular": { - Config: &Config{ - TLS: TLSConfig{ - GetCert: nil, - GetCertHolder: &GetCertHolder{}, - }, - Dial: nil, - DialHolder: nil, - }, - Err: true, - }, - "non-nil dial holder and non-nil regular": { - Config: &Config{ - TLS: TLSConfig{ - GetCert: nil, - GetCertHolder: nil, - }, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, - DialHolder: &DialHolder{}, - }, - Err: true, - }, - "non-nil cert holder and non-nil regular": { - Config: &Config{ - TLS: TLSConfig{ - GetCert: func() (*tls.Certificate, error) { return nil, nil }, - GetCertHolder: &GetCertHolder{}, - }, - Dial: nil, - DialHolder: nil, - }, - Err: true, - }, - "non-nil dial holder+internal and non-nil regular": { - Config: &Config{ - TLS: TLSConfig{ - GetCert: nil, - GetCertHolder: nil, - }, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, - DialHolder: &DialHolder{ - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, - }, - }, - Err: true, - }, - "non-nil cert holder+internal and non-nil regular": { - Config: &Config{ - TLS: TLSConfig{ - GetCert: func() (*tls.Certificate, error) { return nil, nil }, - GetCertHolder: &GetCertHolder{ - GetCert: func() (*tls.Certificate, error) { return nil, nil }, - }, - }, - Dial: nil, - DialHolder: nil, - }, - Err: true, - }, - "non-nil holders+internal and non-nil regular with correct address": { - Config: &Config{ - TLS: TLSConfig{ - GetCert: globalGetCert.GetCert, GetCertHolder: globalGetCert, }, - Dial: globalDial.Dial, DialHolder: globalDial, }, Err: false,