From b7e3da037231b24bd3f86126e1600add1f0b1b32 Mon Sep 17 00:00:00 2001 From: Haowei Cai Date: Wed, 7 Oct 2020 15:44:27 -0700 Subject: [PATCH] don't cache transports for incomparable configs Co-authored-by: Jordan Liggitt Kubernetes-commit: a5ad745376432db81491de7e9a6102e74ba45c26 --- transport/cache.go | 45 ++++++++++++++++++++++++----------------- transport/cache_test.go | 27 ++++++++++++++++++------- 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/transport/cache.go b/transport/cache.go index 980d36ae1..c0651bbaa 100644 --- a/transport/cache.go +++ b/transport/cache.go @@ -44,10 +44,8 @@ type tlsCacheKey struct { caData string certData string keyData string - getCert string serverName string nextProtos string - dial string disableCompression bool } @@ -56,22 +54,24 @@ func (t tlsCacheKey) String() string { if len(t.keyData) > 0 { keyText = "" } - return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, getCert: %s, serverName:%s, dial:%s disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.getCert, t.serverName, t.dial, t.disableCompression) + return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression) } func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { - key, err := tlsConfigKey(config) + key, canCache, err := tlsConfigKey(config) if err != nil { return nil, err } - // Ensure we only create a single transport for the given TLS options - c.mu.Lock() - defer c.mu.Unlock() + if canCache { + // Ensure we only create a single transport for the given TLS options + c.mu.Lock() + defer c.mu.Unlock() - // See if we already have a custom transport for this config - if t, ok := c.transports[key]; ok { - return t, nil + // See if we already have a custom transport for this config + if t, ok := c.transports[key]; ok { + return t, nil + } } // Get the TLS options for this client config @@ -91,8 +91,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { KeepAlive: 30 * time.Second, }).DialContext } - // Cache a single transport for these options - c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{ + transport := utilnet.SetTransportDefaults(&http.Transport{ Proxy: http.ProxyFromEnvironment, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, @@ -100,24 +99,34 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { DialContext: dial, DisableCompression: config.DisableCompression, }) - return c.transports[key], nil + + if canCache { + // Cache a single transport for these options + c.transports[key] = transport + } + + return transport, nil } // tlsConfigKey returns a unique key for tls.Config objects returned from TLSConfigFor -func tlsConfigKey(c *Config) (tlsCacheKey, error) { +func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) { // Make sure ca/key/cert content is loaded if err := loadTLSFiles(c); err != nil { - return tlsCacheKey{}, err + return tlsCacheKey{}, false, err } + + if c.TLS.GetCert != nil || c.Dial != nil { + // cannot determine equality for functions + return tlsCacheKey{}, false, nil + } + return tlsCacheKey{ insecure: c.TLS.Insecure, caData: string(c.TLS.CAData), certData: string(c.TLS.CertData), keyData: string(c.TLS.KeyData), - getCert: fmt.Sprintf("%p", c.TLS.GetCert), serverName: c.TLS.ServerName, nextProtos: strings.Join(c.TLS.NextProtos, ","), - dial: fmt.Sprintf("%p", c.Dial), disableCompression: c.DisableCompression, - }, nil + }, true, nil } diff --git a/transport/cache_test.go b/transport/cache_test.go index 8b9779e68..c6d06fcab 100644 --- a/transport/cache_test.go +++ b/transport/cache_test.go @@ -36,16 +36,24 @@ func TestTLSConfigKey(t *testing.T) { } for nameA, valueA := range identicalConfigurations { for nameB, valueB := range identicalConfigurations { - keyA, err := tlsConfigKey(valueA) + keyA, canCache, err := tlsConfigKey(valueA) if err != nil { t.Errorf("Unexpected error for %q: %v", nameA, err) continue } - keyB, err := tlsConfigKey(valueB) + if !canCache { + t.Errorf("Unexpected canCache=false") + continue + } + keyB, canCache, err := tlsConfigKey(valueB) if err != nil { t.Errorf("Unexpected error for %q: %v", nameB, err) continue } + if !canCache { + t.Errorf("Unexpected canCache=false") + continue + } if keyA != keyB { t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) continue @@ -131,12 +139,12 @@ func TestTLSConfigKey(t *testing.T) { } for nameA, valueA := range uniqueConfigurations { for nameB, valueB := range uniqueConfigurations { - keyA, err := tlsConfigKey(valueA) + keyA, canCacheA, err := tlsConfigKey(valueA) if err != nil { t.Errorf("Unexpected error for %q: %v", nameA, err) continue } - keyB, err := tlsConfigKey(valueB) + keyB, canCacheB, err := tlsConfigKey(valueB) if err != nil { t.Errorf("Unexpected error for %q: %v", nameB, err) continue @@ -147,12 +155,17 @@ func TestTLSConfigKey(t *testing.T) { if keyA != keyB { t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) } + if canCacheA != canCacheB { + t.Errorf("Expected identical canCache %q and %q, got:\n\t%v\n\t%v", nameA, nameB, canCacheA, canCacheB) + } continue } - if keyA == keyB { - t.Errorf("Expected unique cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) - continue + if canCacheA && canCacheB { + if keyA == keyB { + t.Errorf("Expected unique cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) + continue + } } } }