don't cache transports for incomparable configs

Co-authored-by: Jordan Liggitt <liggitt@google.com>

Kubernetes-commit: 0765ba8e54f0c9e5f221e505a24759fa18beaf2e
This commit is contained in:
Haowei Cai 2020-10-07 15:44:27 -07:00 committed by Kubernetes Publisher
parent 27421eae1f
commit a3299cf8e9
2 changed files with 47 additions and 48 deletions

View File

@ -47,12 +47,9 @@ type tlsCacheKey struct {
keyData string keyData string
certFile string certFile string
keyFile string keyFile string
getCert string
serverName string serverName string
nextProtos string nextProtos string
dial string
disableCompression bool disableCompression bool
proxy string
} }
func (t tlsCacheKey) String() string { func (t tlsCacheKey) String() string {
@ -60,22 +57,24 @@ func (t tlsCacheKey) String() string {
if len(t.keyData) > 0 { if len(t.keyData) > 0 {
keyText = "<redacted>" keyText = "<redacted>"
} }
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, getCert: %s, serverName:%s, dial:%s disableCompression:%t, proxy: %s", t.insecure, t.caData, t.certData, keyText, t.getCert, t.serverName, t.dial, t.disableCompression, t.proxy) return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression)
} }
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
key, err := tlsConfigKey(config) key, canCache, err := tlsConfigKey(config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Ensure we only create a single transport for the given TLS options if canCache {
c.mu.Lock() // Ensure we only create a single transport for the given TLS options
defer c.mu.Unlock() c.mu.Lock()
defer c.mu.Unlock()
// See if we already have a custom transport for this config // See if we already have a custom transport for this config
if t, ok := c.transports[key]; ok { if t, ok := c.transports[key]; ok {
return t, nil return t, nil
}
} }
// Get the TLS options for this client config // Get the TLS options for this client config
@ -110,8 +109,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
proxy = config.Proxy proxy = config.Proxy
} }
// Cache a single transport for these options transport := utilnet.SetTransportDefaults(&http.Transport{
c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{
Proxy: proxy, Proxy: proxy,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
@ -119,24 +117,33 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
DialContext: dial, DialContext: dial,
DisableCompression: config.DisableCompression, DisableCompression: config.DisableCompression,
}) })
return c.transports[key], nil
if canCache {
// Cache a single transport for these options
c.transports[key] = transport
}
return transport, nil
} }
// tlsConfigKey returns a unique key for tls.Config objects returned from TLSConfigFor // tlsConfigKey returns a unique key for tls.Config objects returned from TLSConfigFor
func tlsConfigKey(c *Config) (tlsCacheKey, error) { func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
// Make sure ca/key/cert content is loaded // Make sure ca/key/cert content is loaded
if err := loadTLSFiles(c); err != nil { if err := loadTLSFiles(c); err != nil {
return tlsCacheKey{}, err return tlsCacheKey{}, false, err
} }
if c.TLS.GetCert != nil || c.Dial != nil || c.Proxy != nil {
// cannot determine equality for functions
return tlsCacheKey{}, false, nil
}
k := tlsCacheKey{ k := tlsCacheKey{
insecure: c.TLS.Insecure, insecure: c.TLS.Insecure,
caData: string(c.TLS.CAData), caData: string(c.TLS.CAData),
getCert: fmt.Sprintf("%p", c.TLS.GetCert),
serverName: c.TLS.ServerName, serverName: c.TLS.ServerName,
nextProtos: strings.Join(c.TLS.NextProtos, ","), nextProtos: strings.Join(c.TLS.NextProtos, ","),
dial: fmt.Sprintf("%p", c.Dial),
disableCompression: c.DisableCompression, disableCompression: c.DisableCompression,
proxy: fmt.Sprintf("%p", c.Proxy),
} }
if c.TLS.ReloadTLSFiles { if c.TLS.ReloadTLSFiles {
@ -147,5 +154,5 @@ func tlsConfigKey(c *Config) (tlsCacheKey, error) {
k.keyData = string(c.TLS.KeyData) k.keyData = string(c.TLS.KeyData)
} }
return k, nil return k, true, nil
} }

View File

@ -21,7 +21,6 @@ import (
"crypto/tls" "crypto/tls"
"net" "net"
"net/http" "net/http"
"net/url"
"testing" "testing"
) )
@ -37,16 +36,24 @@ func TestTLSConfigKey(t *testing.T) {
} }
for nameA, valueA := range identicalConfigurations { for nameA, valueA := range identicalConfigurations {
for nameB, valueB := range identicalConfigurations { for nameB, valueB := range identicalConfigurations {
keyA, err := tlsConfigKey(valueA) keyA, canCache, err := tlsConfigKey(valueA)
if err != nil { if err != nil {
t.Errorf("Unexpected error for %q: %v", nameA, err) t.Errorf("Unexpected error for %q: %v", nameA, err)
continue continue
} }
keyB, err := tlsConfigKey(valueB) if !canCache {
t.Errorf("Unexpected canCache=false")
continue
}
keyB, canCache, err := tlsConfigKey(valueB)
if err != nil { if err != nil {
t.Errorf("Unexpected error for %q: %v", nameB, err) t.Errorf("Unexpected error for %q: %v", nameB, err)
continue continue
} }
if !canCache {
t.Errorf("Unexpected canCache=false")
continue
}
if keyA != keyB { if keyA != keyB {
t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue continue
@ -132,12 +139,12 @@ func TestTLSConfigKey(t *testing.T) {
} }
for nameA, valueA := range uniqueConfigurations { for nameA, valueA := range uniqueConfigurations {
for nameB, valueB := range uniqueConfigurations { for nameB, valueB := range uniqueConfigurations {
keyA, err := tlsConfigKey(valueA) keyA, canCacheA, err := tlsConfigKey(valueA)
if err != nil { if err != nil {
t.Errorf("Unexpected error for %q: %v", nameA, err) t.Errorf("Unexpected error for %q: %v", nameA, err)
continue continue
} }
keyB, err := tlsConfigKey(valueB) keyB, canCacheB, err := tlsConfigKey(valueB)
if err != nil { if err != nil {
t.Errorf("Unexpected error for %q: %v", nameB, err) t.Errorf("Unexpected error for %q: %v", nameB, err)
continue continue
@ -148,33 +155,18 @@ func TestTLSConfigKey(t *testing.T) {
if keyA != keyB { if keyA != keyB {
t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
} }
if canCacheA != canCacheB {
t.Errorf("Expected identical canCache %q and %q, got:\n\t%v\n\t%v", nameA, nameB, canCacheA, canCacheB)
}
continue continue
} }
if keyA == keyB { if canCacheA && canCacheB {
t.Errorf("Expected unique cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) if keyA == keyB {
continue t.Errorf("Expected unique cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB)
continue
}
} }
} }
} }
} }
func TestTLSConfigKeyFuncPtr(t *testing.T) {
keys := make(map[tlsCacheKey]struct{})
makeKey := func(p func(*http.Request) (*url.URL, error)) tlsCacheKey {
key, err := tlsConfigKey(&Config{Proxy: p})
if err != nil {
t.Fatalf("Unexpected error creating cache key: %v", err)
}
return key
}
keys[makeKey(http.ProxyFromEnvironment)] = struct{}{}
keys[makeKey(http.ProxyFromEnvironment)] = struct{}{}
keys[makeKey(http.ProxyURL(nil))] = struct{}{}
keys[makeKey(nil)] = struct{}{}
if got, want := len(keys), 3; got != want {
t.Fatalf("Unexpected number of keys: got=%d want=%d", got, want)
}
}