diff --git a/transport/cache.go b/transport/cache.go index 3ec4e193..fa2afb1f 100644 --- a/transport/cache.go +++ b/transport/cache.go @@ -47,12 +47,9 @@ type tlsCacheKey struct { keyData string certFile string keyFile string - getCert string serverName string nextProtos string - dial string disableCompression bool - proxy string } func (t tlsCacheKey) String() string { @@ -60,22 +57,24 @@ func (t tlsCacheKey) String() string { if len(t.keyData) > 0 { keyText = "" } - return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, getCert: %s, serverName:%s, dial:%s disableCompression:%t, proxy: %s", t.insecure, t.caData, t.certData, keyText, t.getCert, t.serverName, t.dial, t.disableCompression, t.proxy) + return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression) } func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { - key, err := tlsConfigKey(config) + key, canCache, err := tlsConfigKey(config) if err != nil { return nil, err } - // Ensure we only create a single transport for the given TLS options - c.mu.Lock() - defer c.mu.Unlock() + if canCache { + // Ensure we only create a single transport for the given TLS options + c.mu.Lock() + defer c.mu.Unlock() - // See if we already have a custom transport for this config - if t, ok := c.transports[key]; ok { - return t, nil + // See if we already have a custom transport for this config + if t, ok := c.transports[key]; ok { + return t, nil + } } // Get the TLS options for this client config @@ -110,8 +109,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { proxy = config.Proxy } - // Cache a single transport for these options - c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{ + transport := utilnet.SetTransportDefaults(&http.Transport{ Proxy: proxy, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, @@ -119,24 +117,33 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { DialContext: dial, DisableCompression: config.DisableCompression, }) - return c.transports[key], nil + + if canCache { + // Cache a single transport for these options + c.transports[key] = transport + } + + return transport, nil } // tlsConfigKey returns a unique key for tls.Config objects returned from TLSConfigFor -func tlsConfigKey(c *Config) (tlsCacheKey, error) { +func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) { // Make sure ca/key/cert content is loaded if err := loadTLSFiles(c); err != nil { - return tlsCacheKey{}, err + return tlsCacheKey{}, false, err } + + if c.TLS.GetCert != nil || c.Dial != nil || c.Proxy != nil { + // cannot determine equality for functions + return tlsCacheKey{}, false, nil + } + k := tlsCacheKey{ insecure: c.TLS.Insecure, caData: string(c.TLS.CAData), - getCert: fmt.Sprintf("%p", c.TLS.GetCert), serverName: c.TLS.ServerName, nextProtos: strings.Join(c.TLS.NextProtos, ","), - dial: fmt.Sprintf("%p", c.Dial), disableCompression: c.DisableCompression, - proxy: fmt.Sprintf("%p", c.Proxy), } if c.TLS.ReloadTLSFiles { @@ -147,5 +154,5 @@ func tlsConfigKey(c *Config) (tlsCacheKey, error) { k.keyData = string(c.TLS.KeyData) } - return k, nil + return k, true, nil } diff --git a/transport/cache_test.go b/transport/cache_test.go index 11ee6253..c6d06fca 100644 --- a/transport/cache_test.go +++ b/transport/cache_test.go @@ -21,7 +21,6 @@ import ( "crypto/tls" "net" "net/http" - "net/url" "testing" ) @@ -37,16 +36,24 @@ func TestTLSConfigKey(t *testing.T) { } for nameA, valueA := range identicalConfigurations { for nameB, valueB := range identicalConfigurations { - keyA, err := tlsConfigKey(valueA) + keyA, canCache, err := tlsConfigKey(valueA) if err != nil { t.Errorf("Unexpected error for %q: %v", nameA, err) continue } - keyB, err := tlsConfigKey(valueB) + if !canCache { + t.Errorf("Unexpected canCache=false") + continue + } + keyB, canCache, err := tlsConfigKey(valueB) if err != nil { t.Errorf("Unexpected error for %q: %v", nameB, err) continue } + if !canCache { + t.Errorf("Unexpected canCache=false") + continue + } if keyA != keyB { t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) continue @@ -132,12 +139,12 @@ func TestTLSConfigKey(t *testing.T) { } for nameA, valueA := range uniqueConfigurations { for nameB, valueB := range uniqueConfigurations { - keyA, err := tlsConfigKey(valueA) + keyA, canCacheA, err := tlsConfigKey(valueA) if err != nil { t.Errorf("Unexpected error for %q: %v", nameA, err) continue } - keyB, err := tlsConfigKey(valueB) + keyB, canCacheB, err := tlsConfigKey(valueB) if err != nil { t.Errorf("Unexpected error for %q: %v", nameB, err) continue @@ -148,33 +155,18 @@ func TestTLSConfigKey(t *testing.T) { if keyA != keyB { t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) } + if canCacheA != canCacheB { + t.Errorf("Expected identical canCache %q and %q, got:\n\t%v\n\t%v", nameA, nameB, canCacheA, canCacheB) + } continue } - if keyA == keyB { - t.Errorf("Expected unique cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) - continue + if canCacheA && canCacheB { + if keyA == keyB { + t.Errorf("Expected unique cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) + continue + } } } } } - -func TestTLSConfigKeyFuncPtr(t *testing.T) { - keys := make(map[tlsCacheKey]struct{}) - makeKey := func(p func(*http.Request) (*url.URL, error)) tlsCacheKey { - key, err := tlsConfigKey(&Config{Proxy: p}) - if err != nil { - t.Fatalf("Unexpected error creating cache key: %v", err) - } - return key - } - - keys[makeKey(http.ProxyFromEnvironment)] = struct{}{} - keys[makeKey(http.ProxyFromEnvironment)] = struct{}{} - keys[makeKey(http.ProxyURL(nil))] = struct{}{} - keys[makeKey(nil)] = struct{}{} - - if got, want := len(keys), 3; got != want { - t.Fatalf("Unexpected number of keys: got=%d want=%d", got, want) - } -}