Merge pull request #136355 from enj/enj/i/tls_cache_gc

Add GC to client-go TLS cache

Kubernetes-commit: bf1abbf2e987883ffacaaf6f84218bcbfd444876
This commit is contained in:
Kubernetes Publisher
2026-03-19 05:42:29 +05:30
6 changed files with 774 additions and 62 deletions

View File

@@ -55,6 +55,13 @@ const (
// "application/json" or "application/apply-patch+yaml", respectively.
ClientsAllowCBOR Feature = "ClientsAllowCBOR"
// owner: @enj
// beta: v1.36
//
// If enabled, the client-go TLS transport cache uses weak pointers to allow
// garbage collection of unused transports, preventing unbounded cache growth.
ClientsAllowTLSCacheGC Feature = "ClientsAllowTLSCacheGC"
// owner: @benluddy
// kep: https://kep.k8s.io/4222
// alpha: 1.32
@@ -111,6 +118,9 @@ var defaultVersionedKubernetesFeatureGates = map[Feature]VersionedSpecs{
ClientsAllowCBOR: {
{Version: version.MustParse("1.32"), Default: false, PreRelease: Alpha},
},
ClientsAllowTLSCacheGC: {
{Version: version.MustParse("1.36"), Default: true, PreRelease: Beta},
},
ClientsPreferCBOR: {
{Version: version.MustParse("1.32"), Default: false, PreRelease: Alpha},
},

View File

@@ -80,7 +80,7 @@ type TransportCacheMetric interface {
}
// TransportCreateCallsMetric counts the number of times a transport is created
// partitioned by the result of the cache: hit, miss, uncacheable
// partitioned by the result of the cache: hit, miss, miss-gc, uncacheable
type TransportCreateCallsMetric interface {
Increment(result string)
}
@@ -91,6 +91,18 @@ type TransportCAReloadsMetric interface {
Increment(result, reason string)
}
// TransportCertRotationGCCallsMetric counts the number of times a cert rotation
// goroutine cancel func is called via GC cleanup. Implementations must be
// safe for concurrent use.
type TransportCertRotationGCCallsMetric interface {
	Increment()
}

// TransportCacheGCCallsMetric counts the number of times a GC cleanup
// attempts to delete a cache entry, partitioned by the result: deleted, skipped.
// "skipped" means the map slot had already been replaced by a newer entry.
type TransportCacheGCCallsMetric interface {
	Increment(result string)
}
var (
// ClientCertExpiry is the expiry time of a client certificate
ClientCertExpiry ExpiryMetric = noopExpiry{}
@@ -123,27 +135,34 @@ var (
// TransportCreateCalls is the metric that counts the number of times a new transport
// is created
TransportCreateCalls TransportCreateCallsMetric = noopTransportCreateCalls{}
// TransportCAReloads is the metric that counts the number of times a CA reload is attempted
TransportCAReloads TransportCAReloadsMetric = noopTransportCAReloads{}
// TransportCertRotationGCCalls counts the number of times a cert rotation goroutine
// cancel func is called via GC cleanup
TransportCertRotationGCCalls TransportCertRotationGCCallsMetric = noopTransportCertRotationGCCalls{}
// TransportCacheGCCalls counts the number of times a GC cleanup attempts
// to delete a transport cache entry, partitioned by result: deleted, skipped.
TransportCacheGCCalls TransportCacheGCCallsMetric = noopTransportCacheGCCalls{}
)
// RegisterOpts contains all the metrics to register. Metrics may be nil.
type RegisterOpts struct {
ClientCertExpiry ExpiryMetric
ClientCertRotationAge DurationMetric
RequestLatency LatencyMetric
ResolverLatency ResolverLatencyMetric
RequestSize SizeMetric
ResponseSize SizeMetric
RateLimiterLatency LatencyMetric
RequestResult ResultMetric
ExecPluginCalls CallsMetric
ExecPluginPolicyCalls PolicyCallsMetric
RequestRetry RetryMetric
TransportCacheEntries TransportCacheMetric
TransportCreateCalls TransportCreateCallsMetric
TransportCAReloads TransportCAReloadsMetric
ClientCertExpiry ExpiryMetric
ClientCertRotationAge DurationMetric
RequestLatency LatencyMetric
ResolverLatency ResolverLatencyMetric
RequestSize SizeMetric
ResponseSize SizeMetric
RateLimiterLatency LatencyMetric
RequestResult ResultMetric
ExecPluginCalls CallsMetric
ExecPluginPolicyCalls PolicyCallsMetric
RequestRetry RetryMetric
TransportCacheEntries TransportCacheMetric
TransportCreateCalls TransportCreateCallsMetric
TransportCAReloads TransportCAReloadsMetric
TransportCertRotationGCCalls TransportCertRotationGCCallsMetric
TransportCacheGCCalls TransportCacheGCCallsMetric
}
// Register registers metrics for the rest client to use. This can
@@ -192,6 +211,12 @@ func Register(opts RegisterOpts) {
if opts.TransportCAReloads != nil {
TransportCAReloads = opts.TransportCAReloads
}
if opts.TransportCertRotationGCCalls != nil {
TransportCertRotationGCCalls = opts.TransportCertRotationGCCalls
}
if opts.TransportCacheGCCalls != nil {
TransportCacheGCCalls = opts.TransportCacheGCCalls
}
})
}
@@ -243,3 +268,11 @@ func (noopTransportCreateCalls) Increment(string) {}
// noopTransportCAReloads is the default, do-nothing TransportCAReloadsMetric
// used until Register installs a real implementation.
type noopTransportCAReloads struct{}

func (noopTransportCAReloads) Increment(result, reason string) {}

// noopTransportCertRotationGCCalls is the default, do-nothing
// TransportCertRotationGCCallsMetric.
type noopTransportCertRotationGCCalls struct{}

func (noopTransportCertRotationGCCalls) Increment() {}

// noopTransportCacheGCCalls is the default, do-nothing
// TransportCacheGCCallsMetric.
type noopTransportCacheGCCalls struct{}

func (noopTransportCacheGCCalls) Increment(string) {}

View File

@@ -252,7 +252,7 @@ func TestCARotationConnectionBehavior(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create transport: %v", err)
}
transport.(*atomicTransportHolder).caRefreshDuration = 500 * time.Millisecond
transport.(*trackedTransport).rt.(*atomicTransportHolder).caRefreshDuration = 500 * time.Millisecond
client := &http.Client{
Transport: transport,

View File

@@ -21,12 +21,14 @@ import (
"fmt"
"net"
"net/http"
"runtime"
"strings"
"sync"
"time"
"weak"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/pkg/util/wait"
clientgofeaturegate "k8s.io/client-go/features"
"k8s.io/client-go/tools/metrics"
"k8s.io/klog/v2"
)
@@ -35,19 +37,20 @@ import (
// same RoundTripper will be returned for configs with identical TLS options If
// the config has no custom TLS options, http.DefaultTransport is returned.
type tlsTransportCache struct {
mu sync.Mutex
transports map[tlsCacheKey]http.RoundTripper
mu sync.Mutex
transports map[tlsCacheKey]weak.Pointer[trackedTransport] // GC-enabled
strongTransports map[tlsCacheKey]http.RoundTripper // GC-disabled
}
// DialerStopCh is stop channel that is passed down to dynamic cert dialer.
// It's exposed as variable for testing purposes to avoid testing for goroutine
// leakages.
var DialerStopCh = wait.NeverStop
const idleConnsPerHost = 25
var tlsCache = &tlsTransportCache{
transports: make(map[tlsCacheKey]http.RoundTripper),
var tlsCache = newTLSCache()
// newTLSCache returns an empty transport cache with both backing maps
// initialized: the weak-pointer map used when TLS cache GC is enabled,
// and the strong map used when it is disabled.
func newTLSCache() *tlsTransportCache {
	c := &tlsTransportCache{}
	c.transports = make(map[tlsCacheKey]weak.Pointer[trackedTransport])
	c.strongTransports = make(map[tlsCacheKey]http.RoundTripper)
	return c
}
type tlsCacheKey struct {
@@ -85,14 +88,18 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
// Ensure we only create a single transport for the given TLS options
c.mu.Lock()
defer c.mu.Unlock()
defer metrics.TransportCacheEntries.Observe(len(c.transports))
defer func() { metrics.TransportCacheEntries.Observe(c.lenLocked()) }()
// See if we already have a custom transport for this config
if t, ok := c.transports[key]; ok {
metrics.TransportCreateCalls.Increment("hit")
return t, nil
if t, ok := c.getLocked(key); ok {
if t != nil {
metrics.TransportCreateCalls.Increment("hit")
return t, nil
}
metrics.TransportCreateCalls.Increment("miss-gc")
} else {
metrics.TransportCreateCalls.Increment("miss")
}
metrics.TransportCreateCalls.Increment("miss")
} else {
metrics.TransportCreateCalls.Increment("uncacheable")
}
@@ -119,6 +126,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
// If we use are reloading files, we need to handle certificate rotation properly
// TODO(jackkleeman): We can also add rotation here when config.HasCertCallback() is true
var cancel context.CancelFunc
if config.TLS.ReloadTLSFiles && tlsConfig != nil && tlsConfig.GetClientCertificate != nil {
// The TLS cache is a singleton, so sharing the same name for all of its
// background activity seems okay.
@@ -126,7 +134,9 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
dynamicCertDialer := certRotatingDialer(logger, tlsConfig.GetClientCertificate, dial)
tlsConfig.GetClientCertificate = dynamicCertDialer.GetClientCertificate
dial = dynamicCertDialer.connDialer.DialContext
go dynamicCertDialer.run(DialerStopCh)
var ctx context.Context
ctx, cancel = context.WithCancel(context.Background())
go dynamicCertDialer.run(ctx.Done())
}
proxy := http.ProxyFromEnvironment
@@ -148,12 +158,95 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
transport = newAtomicTransportHolder(config.TLS.CAFile, config.TLS.CAData, httpTransport)
}
if canCache {
// Cache a single transport for these options
c.transports[key] = transport
if !canCache && cancel == nil {
return transport, nil // uncacheable config with no cert rotation - nothing to GC
}
return transport, nil
if !clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowTLSCacheGC) {
if canCache {
c.strongTransports[key] = transport
}
return transport, nil // cancel is intentionally discarded and the cert rotation go routine leaks
}
transportWithGC := &trackedTransport{rt: transport}
if cancel != nil {
// capture metric as local var so that cleanups do not influence other tests via globals
transportCertRotationGCCalls := metrics.TransportCertRotationGCCalls
runtime.AddCleanup(transportWithGC, func(_ struct{}) {
cancel()
transportCertRotationGCCalls.Increment()
}, struct{}{})
}
if canCache {
wp := weak.Make(transportWithGC)
c.transports[key] = wp
// capture metrics as local vars so that cleanups do not influence other tests via globals
transportCacheGCCalls := metrics.TransportCacheGCCalls
transportCacheEntries := metrics.TransportCacheEntries
runtime.AddCleanup(transportWithGC, func(key tlsCacheKey) {
c.mu.Lock()
defer c.mu.Unlock()
// make sure we only delete the weak pointer created by this specific setLocked call
if c.transports[key] != wp {
transportCacheGCCalls.Increment("skipped")
return
}
delete(c.transports, key)
transportCacheGCCalls.Increment("deleted")
transportCacheEntries.Observe(c.lenLocked())
}, key)
}
return transportWithGC, nil
}
// getLocked looks up key while c.mu is held.
//
// With the GC feature gate disabled it reads the strong map directly.
// With the gate enabled it dereferences the weak pointer; a present key whose
// target has been garbage collected is reported as (nil, true) so callers can
// distinguish "miss-gc" from a plain miss. A literal nil is returned in that
// case to avoid handing back a typed-nil *trackedTransport inside the
// interface value.
func (c *tlsTransportCache) getLocked(key tlsCacheKey) (http.RoundTripper, bool) {
	if !clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowTLSCacheGC) {
		rt, found := c.strongTransports[key]
		return rt, found
	}
	wp, found := c.transports[key]
	if !found {
		return nil, false
	}
	if tracked := wp.Value(); tracked != nil {
		return tracked, true
	}
	return nil, true // key exists but the transport has been garbage collected
}
// lenLocked reports the size of whichever map the active feature-gate
// configuration uses. Callers must hold c.mu.
func (c *tlsTransportCache) lenLocked() int {
	if clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowTLSCacheGC) {
		return len(c.transports)
	}
	return len(c.strongTransports)
}
// trackedTransport wraps an http.RoundTripper so it can serve as the
// weak.Pointer target in the TLS transport cache. Once every caller drops
// its reference to this object, GC cleanups registered against it evict the
// cache entry and cancel any cert rotation goroutine.
type trackedTransport struct {
	rt http.RoundTripper
}

var (
	_ http.RoundTripper           = &trackedTransport{}
	_ utilnet.RoundTripperWrapper = &trackedTransport{}
)

// RoundTrip forwards the request to the wrapped transport.
func (t *trackedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	return t.rt.RoundTrip(req)
}

// WrappedRoundTripper exposes the underlying transport for unwrapping.
func (t *trackedTransport) WrappedRoundTripper() http.RoundTripper {
	return t.rt
}
// tlsConfigKey returns a unique key for tls.Config objects returned from TLSConfigFor

View File

@@ -25,11 +25,17 @@ import (
"net/url"
"os"
"reflect"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"weak"
"k8s.io/apimachinery/pkg/util/wait"
clientgofeaturegate "k8s.io/client-go/features"
clientfeaturestesting "k8s.io/client-go/features/testing"
"k8s.io/client-go/tools/metrics"
)
func TestTLSConfigKey(t *testing.T) {
@@ -237,7 +243,6 @@ func TestTLSConfigKeyCARotationDisabled(t *testing.T) {
// TestTLSTransportCacheCARotation tests transport cache behavior with CA rotation
func TestTLSTransportCacheCARotation(t *testing.T) {
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true)
caFile := writeCAFile(t, []byte(testCACert1))
testCases := []struct {
@@ -289,10 +294,8 @@ func TestTLSTransportCacheCARotation(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create new cache for testing
tlsCaches := &tlsTransportCache{
transports: make(map[tlsCacheKey]http.RoundTripper),
}
createCalls, _, _, _ := installFakeMetrics(t)
tlsCaches := newTLSCache()
rt, err := tlsCaches.get(tc.config)
if err != nil {
@@ -304,10 +307,33 @@ func TestTLSTransportCacheCARotation(t *testing.T) {
if rt != http.DefaultTransport {
t.Errorf("Expected default transport, got %T", rt)
}
// canCache=true so getLocked fires "miss", but DefaultTransport
// is returned before setLocked — the entry is never stored.
if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "miss" {
t.Errorf("expected [miss], got %v", calls)
}
// A second call is also "miss" — nothing was stored.
rt2, err := tlsCaches.get(tc.config)
if err != nil {
t.Fatal(err)
}
if rt2 != http.DefaultTransport {
t.Errorf("second call: expected default transport, got %T", rt2)
}
if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "miss" {
t.Errorf("expected [miss] on second call (nothing stored), got %v", calls)
}
requireCacheLen(t, tlsCaches, 0)
return
}
if tc.expectWrapper {
if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "miss" {
t.Errorf("expected [miss] on first get, got %v", calls)
}
if rt := rt.(*trackedTransport).rt; tc.expectWrapper {
// Should be wrapped in atomicTransportHolder
if _, ok := rt.(*atomicTransportHolder); !ok {
t.Errorf("Expected atomicTransportHolder, got %T", rt)
@@ -327,24 +353,14 @@ func TestTLSTransportCacheCARotation(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error on second call: %v", err)
}
if rt != rt2 {
t.Error("Expected same transport instance from cache")
}
// Verify cache size
tlsCaches.mu.Lock()
cacheSize := len(tlsCaches.transports)
tlsCaches.mu.Unlock()
expectedCacheSize := 1
if !tc.expectCacheable {
expectedCacheSize = 0
if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "hit" {
t.Errorf("expected [hit] on second get, got %v", calls)
}
if cacheSize != expectedCacheSize {
t.Errorf("Expected %d transports in cache, got %d", expectedCacheSize, cacheSize)
}
requireCacheLen(t, tlsCaches, 1)
})
}
}
@@ -353,13 +369,14 @@ func TestTLSTransportCacheCARotationDisabled(t *testing.T) {
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false)
caFile := writeCAFile(t, []byte(testCACert1))
cache := &tlsTransportCache{transports: make(map[tlsCacheKey]http.RoundTripper)}
cache := newTLSCache()
rt, err := cache.get(&Config{TLS: TLSConfig{CAFile: caFile}})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
rt = rt.(*trackedTransport).rt
if _, ok := rt.(*atomicTransportHolder); ok {
t.Error("Expected plain *http.Transport when feature gate is disabled, got atomicTransportHolder")
}
@@ -381,9 +398,7 @@ func TestEmptyCAFileRotationLifecycle(t *testing.T) {
},
}
tlsCaches := &tlsTransportCache{
transports: make(map[tlsCacheKey]http.RoundTripper),
}
tlsCaches := newTLSCache()
rt, err := tlsCaches.get(config)
if err != nil {
@@ -391,7 +406,7 @@ func TestEmptyCAFileRotationLifecycle(t *testing.T) {
}
// Verify newAtomicTransportHolder is successfully generated
holder, ok := rt.(*atomicTransportHolder)
holder, ok := rt.(*trackedTransport).rt.(*atomicTransportHolder)
if !ok {
t.Fatalf("Expected atomicTransportHolder, got %T", rt)
}
@@ -424,3 +439,561 @@ func TestEmptyCAFileRotationLifecycle(t *testing.T) {
t.Fatal("Expected RootCAs to be populated after writing valid cert data and refreshing")
}
}
// TestCacheHoldAfterCARotation verifies that holding the *atomicTransportHolder
// keeps the cache entry alive even after CA rotation swaps the inner transport.
func TestCacheHoldAfterCARotation(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true)
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true)
	_, gcCalls, _, _ := installFakeMetrics(t)
	caFile := writeCAFile(t, []byte(testCACert1))
	cache := newTLSCache()
	rt, err := cache.get(&Config{TLS: TLSConfig{ServerName: "reload-test", CAFile: caFile}})
	if err != nil {
		t.Fatal(err)
	}
	requireCacheLen(t, cache, 1)
	holder := rt.(*trackedTransport).rt.(*atomicTransportHolder)
	originalInner := holder.getTransport(context.Background())
	if originalInner == nil {
		t.Fatal("expected non-nil transport")
	}
	// Simulate CA rotation.
	if err := os.WriteFile(caFile, []byte(testCACert2), 0644); err != nil {
		t.Fatal(err)
	}
	// Backdate the refresh timestamp so the next getTransport treats the CA
	// data as stale and re-reads the file (confirmed by the inequality below).
	holder.mu.Lock()
	holder.transportLastChecked = time.Now().Add(-time.Hour)
	holder.mu.Unlock()
	newInner := holder.getTransport(context.Background())
	if newInner == nil {
		t.Fatal("expected non-nil transport after rotation")
	}
	if newInner == originalInner {
		t.Fatal("expected transport to change after CA rotation")
	}
	// Cache entry must survive because the holder is alive.
	for range 3 {
		runtime.GC()
	}
	requireCacheLen(t, cache, 1)
	// rt is guaranteed reachable until here; only after this point may the
	// GC collect the trackedTransport and run its cleanup.
	runtime.KeepAlive(rt)
	pollCacheSizeWithGC(t, cache, 0)
	if calls := gcCalls.reset(); len(calls) != 1 || calls[0] != "deleted" {
		t.Errorf("expected [deleted] after eviction, got %v", calls)
	}
}
// TestCacheGCDisabledNoEviction verifies that with the GC feature gate disabled,
// cache entries are stored in strongTransports and never evicted.
func TestCacheGCDisabledNoEviction(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, false)
	createCalls, gcCalls, _, cacheEntries := installFakeMetrics(t)
	cache := newTLSCache()
	rt, err := cache.get(&Config{TLS: TLSConfig{ServerName: "gc-disabled-test"}})
	if err != nil {
		t.Fatal(err)
	}
	if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "miss" {
		t.Errorf("expected [miss], got %v", calls)
	}
	// With GC disabled the transport must not be wrapped for weak tracking.
	if _, ok := rt.(*trackedTransport); ok {
		t.Error("expected plain transport, not *trackedTransport, when GC is disabled")
	}
	rt2, err := cache.get(&Config{TLS: TLSConfig{ServerName: "gc-disabled-test"}})
	if err != nil {
		t.Fatal(err)
	}
	if rt != rt2 {
		t.Error("expected cache hit to return same transport")
	}
	if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "hit" {
		t.Errorf("expected [hit], got %v", calls)
	}
	requireCacheLen(t, cache, 1)
	runtime.KeepAlive(rt)
	// The strong map reference must keep the entry alive across any number
	// of GC cycles.
	for range 10 {
		runtime.GC()
	}
	requireCacheLen(t, cache, 1)
	if calls := gcCalls.reset(); len(calls) != 0 {
		t.Errorf("expected no GC cache calls when GC is disabled, got %v", calls)
	}
	// Both get() calls observe 1 — the defer evaluates lenLocked() after setLocked stores.
	if values := cacheEntries.reset(); len(values) != 2 || values[0] != 1 || values[1] != 1 {
		t.Errorf("expected cache entries observations [1 1], got %v", values)
	}
}
// TestCacheReviveAfterDrop verifies that when rt1 is still alive (via KeepAlive),
// a second get() for the same key is a deterministic cache hit.
func TestCacheReviveAfterDrop(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true)
	createCalls, _, _, _ := installFakeMetrics(t)
	cache := newTLSCache()
	config := &Config{TLS: TLSConfig{ServerName: "keep-alive-test"}}
	rt1, err := cache.get(config)
	if err != nil {
		t.Fatal(err)
	}
	requireCacheLen(t, cache, 1)
	createCalls.reset() // discard the initial "miss" observation
	// rt1 is alive here, so the weak pointer must still resolve.
	rt2, err := cache.get(config)
	if err != nil {
		t.Fatal(err)
	}
	// rt1 and rt2 are the same object (cache hit).
	if rt1 != rt2 {
		t.Error("expected same transport (cache hit) since rt1 is still alive")
	}
	if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "hit" {
		t.Errorf("expected [hit], got %v", calls)
	}
	requireCacheLen(t, cache, 1)
	// After this point no references remain, so GC may evict the entry.
	runtime.KeepAlive(rt1)
	pollCacheSizeWithGC(t, cache, 0)
}
// TestUncacheableCertRotationLeak verifies that the cert rotation goroutine
// is stopped when an uncacheable transport is garbage collected.
func TestUncacheableCertRotationLeak(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true)
	_, _, rotationGCCalls, _ := installFakeMetrics(t)
	certFile := writeCAFile(t, []byte(certData))
	keyFile := writeCAFile(t, []byte(keyData))
	cache := newTLSCache()
	baseline := runtime.NumGoroutine()
	rt, err := cache.get(&Config{
		TLS: TLSConfig{
			CertFile: certFile,
			KeyFile:  keyFile,
		},
		Proxy: func(*http.Request) (*url.URL, error) { return nil, nil }, // force uncacheable
	})
	if err != nil {
		t.Fatal(err)
	}
	// Uncacheable configs never enter the cache map.
	requireCacheLen(t, cache, 0)
	afterCreate := runtime.NumGoroutine()
	if afterCreate <= baseline {
		t.Fatalf("expected goroutine count to increase after creating transport with cert rotation, got baseline=%d after=%d", baseline, afterCreate)
	}
	// Drop the only reference; the GC cleanup should cancel the rotation goroutine.
	runtime.KeepAlive(rt)
	err = wait.PollUntilContextTimeout(t.Context(), 10*time.Millisecond, 10*time.Second, true, func(_ context.Context) (bool, error) {
		runtime.GC()
		return runtime.NumGoroutine() <= baseline, nil
	})
	if err != nil {
		t.Errorf("goroutine leak: cert rotation goroutine was not stopped for uncacheable transport (baseline=%d current=%d)", baseline, runtime.NumGoroutine())
	}
	if n := rotationGCCalls.count.Load(); n != 1 {
		t.Errorf("expected TransportCertRotationGCCalls=1, got %d", n)
	}
}
// TestCacheableCertRotationLeak verifies that the cert rotation goroutine
// is stopped when a cacheable transport with cert rotation is garbage collected.
// This exercises the canCache=true && cancel != nil path through setLocked.
func TestCacheableCertRotationLeak(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true)
	_, gcCalls, rotationGCCalls, _ := installFakeMetrics(t)
	certFile := writeCAFile(t, []byte(certData))
	keyFile := writeCAFile(t, []byte(keyData))
	cache := newTLSCache()
	baseline := runtime.NumGoroutine()
	// CertFile+KeyFile triggers ReloadTLSFiles and the cert rotation goroutine.
	// No Proxy means canCache=true.
	rt, err := cache.get(&Config{
		TLS: TLSConfig{
			CertFile:   certFile,
			KeyFile:    keyFile,
			ServerName: "cacheable-cert-rotation-test",
		},
	})
	if err != nil {
		t.Fatal(err)
	}
	requireCacheLen(t, cache, 1)
	afterCreate := runtime.NumGoroutine()
	if afterCreate <= baseline {
		t.Fatalf("expected goroutine count to increase after creating transport with cert rotation, got baseline=%d after=%d", baseline, afterCreate)
	}
	// Drop the only reference past this point so both cleanups can fire.
	runtime.KeepAlive(rt)
	gcCalls.reset()
	pollCacheSizeWithGC(t, cache, 0)
	err = wait.PollUntilContextTimeout(t.Context(), 10*time.Millisecond, 10*time.Second, true, func(_ context.Context) (bool, error) {
		runtime.GC()
		return runtime.NumGoroutine() <= baseline, nil
	})
	if err != nil {
		t.Errorf("goroutine leak: cert rotation goroutine was not stopped for cacheable transport (baseline=%d current=%d)", baseline, runtime.NumGoroutine())
	}
	if n := rotationGCCalls.count.Load(); n != 1 {
		t.Errorf("expected TransportCertRotationGCCalls=1, got %d", n)
	}
	if calls := gcCalls.reset(); len(calls) != 1 || calls[0] != "deleted" {
		t.Errorf("expected [deleted], got %v", calls)
	}
}
// TestCacheGCDisabledCertRotationNoCancel verifies that with the GC feature gate
// disabled, the cert rotation goroutine is NOT stopped (pre-GC behavior).
func TestCacheGCDisabledCertRotationNoCancel(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, false)
	_, _, rotationGCCalls, _ := installFakeMetrics(t)
	certFile := writeCAFile(t, []byte(certData))
	keyFile := writeCAFile(t, []byte(keyData))
	cache := newTLSCache()
	baseline := runtime.NumGoroutine()
	rt, err := cache.get(&Config{
		TLS: TLSConfig{
			CertFile:   certFile,
			KeyFile:    keyFile,
			ServerName: "gc-disabled-cert-rotation",
		},
	})
	if err != nil {
		t.Fatal(err)
	}
	afterCreate := runtime.NumGoroutine()
	if afterCreate <= baseline {
		t.Fatalf("expected cert rotation goroutine to start, got baseline=%d after=%d", baseline, afterCreate)
	}
	// Drop rt after this point
	// if GC cleanup were enabled, cancel would fire and the goroutine would stop.
	runtime.KeepAlive(rt)
	for range 10 {
		runtime.GC()
		runtime.Gosched()
	}
	time.Sleep(50 * time.Millisecond) // let any potential cleanup callbacks run
	if n := rotationGCCalls.count.Load(); n != 0 {
		t.Errorf("expected TransportCertRotationGCCalls=0 when GC disabled, got %d", n)
	}
	// Goroutine must still be running — cancel was intentionally discarded.
	if runtime.NumGoroutine() <= baseline {
		t.Error("cert rotation goroutine was stopped — it should run indefinitely when GC is disabled")
	}
}
// TestCacheStaleEvictionSkipped verifies that when a GC cleanup fires for an
// old entry after the map slot has been replaced, the deletion is skipped.
// We force this by replacing the map entry while the old trackedTransport
// is still alive, then dropping it so its cleanup fires against the new entry.
func TestCacheStaleEvictionSkipped(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true)
	_, gcCalls, _, _ := installFakeMetrics(t)
	cache := newTLSCache()
	cfg := &Config{TLS: TLSConfig{ServerName: "stale-test"}}
	rt1, err := cache.get(cfg)
	if err != nil {
		t.Fatal(err)
	}
	requireCacheLen(t, cache, 1)
	// Replace the map entry while rt1 is still alive. This simulates a
	// concurrent get() that created a new entry for the same key.
	key, _, _ := tlsConfigKey(cfg)
	cache.mu.Lock()
	replacement := &trackedTransport{rt: &http.Transport{}}
	cache.transports[key] = weak.Make(replacement)
	cache.mu.Unlock()
	// Drop rt1 after this point. Its cleanup fires and sees c.transports[key] != wp -> "skipped".
	runtime.KeepAlive(rt1)
	gcCalls.reset()
	// Poll until the stale entry's cleanup has been observed.
	err = wait.PollUntilContextTimeout(t.Context(), 10*time.Millisecond, 10*time.Second, true, func(_ context.Context) (bool, error) {
		runtime.GC()
		gcCalls.mu.Lock()
		n := len(gcCalls.calls)
		gcCalls.mu.Unlock()
		return n > 0, nil
	})
	if err != nil {
		t.Fatal("timed out waiting for stale cleanup to fire")
	}
	calls := gcCalls.reset()
	if len(calls) != 1 || calls[0] != "skipped" {
		t.Errorf("expected [skipped], got %v", calls)
	}
	// The replacement entry must still be in the map because of the keep alive
	requireCacheLen(t, cache, 1)
	runtime.KeepAlive(replacement)
}
// TestCacheLeak does a black box test of TLS cache GC from New.
// It uses a large number of cache entries generated in parallel.
func TestCacheLeak(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true)
	_, gcCalls, _, cacheEntries := installFakeMetrics(t)
	pollCacheSizeWithGC(t, tlsCache, 0) // clean start with global cache
	rt1, err := New(&Config{TLS: TLSConfig{ServerName: "1"}})
	if err != nil {
		t.Fatal(err)
	}
	rt2, err := New(&Config{TLS: TLSConfig{ServerName: "2"}})
	if err != nil {
		t.Fatal(err)
	}
	rt3, err := New(&Config{TLS: TLSConfig{ServerName: "1"}})
	if err != nil {
		t.Fatal(err)
	}
	requireCacheLen(t, tlsCache, 2) // rt1 and rt2 (rt3 is the same as rt1)
	var wg wait.Group
	var d net.Dialer
	var rts []http.RoundTripper
	var rtsLock sync.Mutex
	for i := range 1_000 { // outer loop forces cache miss via dialer
		dh := &DialHolder{Dial: d.DialContext}
		for range i%7 + 1 { // inner loop exercises each cache value having 1 to N references
			wg.Start(func() {
				rt, err := New(&Config{DialHolder: dh})
				if err != nil {
					panic(err)
				}
				rtsLock.Lock()
				rts = append(rts, rt) // keep a live reference to the round tripper
				rtsLock.Unlock()
			})
		}
	}
	wg.Wait()
	requireCacheLen(t, tlsCache, 1_000+2) // rts and rt1 and rt2 (rt3 is the same as rt1)
	// rts stays reachable until here; after this the 1000 dialer entries may be collected.
	runtime.KeepAlive(rts) // prevent round trippers from being GC'd too early
	// Reset before the eviction we want to measure.
	gcCalls.reset()
	cacheEntries.reset()
	pollCacheSizeWithGC(t, tlsCache, 2) // rt1 and rt2 (rt3 is the same as rt1)
	calls := gcCalls.reset()
	deletedCount := 0
	for _, c := range calls {
		if c == "deleted" {
			deletedCount++
		} else {
			t.Errorf("expected deleted call, got %s", c)
		}
	}
	if deletedCount != 1_000 {
		t.Errorf("expected 1000 deleted calls, got %d (total calls: %d)", deletedCount, len(calls))
	}
	// Each "deleted" cleanup calls Observe(lenLocked()). The last observation
	// should be 2 (rt1 and rt2 remain).
	evictionObservations := cacheEntries.reset()
	if len(evictionObservations) != 1000 {
		t.Errorf("expected exactly 1000 cache entries observations from GC cleanup, got %d", len(evictionObservations))
	} else if last := evictionObservations[len(evictionObservations)-1]; last != 2 {
		t.Errorf("expected last cache entries observation to be 2 (rt1+rt2 remain), got %d", last)
	}
	// Finally drop rt1/rt2/rt3 and expect the global cache to fully drain.
	runtime.KeepAlive(rt1)
	runtime.KeepAlive(rt2)
	runtime.KeepAlive(rt3)
	pollCacheSizeWithGC(t, tlsCache, 0)
	calls = gcCalls.reset()
	if len(calls) != 2 || calls[0] != "deleted" || calls[1] != "deleted" {
		t.Errorf("expected 2 deleted calls for rt1/rt2, got %v", calls)
	}
}
// requireCacheLen asserts that the cache currently holds exactly want
// entries, failing the test immediately otherwise.
func requireCacheLen(t *testing.T, c *tlsTransportCache, want int) {
	t.Helper()
	// Read the length exactly once: the original double call took the cache
	// lock twice, so a concurrent GC cleanup could make the failure message
	// report a different value than the one actually compared.
	if got := cacheLen(t, c); got != want {
		t.Fatalf("cache len %d, want %d", got, want)
	}
}
// cacheLen returns the entry count of whichever map the active feature-gate
// configuration uses, and fails the test if the inactive map is non-empty
// (entries must never leak into the wrong map).
func cacheLen(t *testing.T, c *tlsTransportCache) int {
	t.Helper()
	c.mu.Lock()
	defer c.mu.Unlock()
	gcEnabled := clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowTLSCacheGC)
	if gcEnabled {
		if n := len(c.strongTransports); n != 0 {
			t.Fatalf("strongTransports len %d must be 0", n)
		}
		return len(c.transports)
	}
	if n := len(c.transports); n != 0 {
		t.Fatalf("transports len %d must be 0", n)
	}
	return len(c.strongTransports)
}
// pollCacheSizeWithGC repeatedly triggers the garbage collector until the
// cache reaches the wanted size, then confirms the size is stable across a
// few additional GC cycles.
func pollCacheSizeWithGC(t *testing.T, c *tlsTransportCache, want int) {
	t.Helper()
	poll := func(_ context.Context) (bool, error) {
		runtime.GC() // run the garbage collector so the cleanups run
		return cacheLen(t, c) == want, nil
	}
	if err := wait.PollUntilContextTimeout(t.Context(), 10*time.Millisecond, 10*time.Second, true, poll); err != nil {
		t.Fatalf("cache len %d, want %d: %v", cacheLen(t, c), want, err)
	}
	// make sure the cache size is stable even when more GC's happen
	// three times should be enough to make the test flake if the implementation is buggy
	for i := 0; i < 3; i++ {
		runtime.GC()
	}
	requireCacheLen(t, c, want)
}
// recordingCreateCalls is a TransportCreateCallsMetric that records every
// result string passed to Increment so tests can inspect the sequence.
type recordingCreateCalls struct {
	mu    sync.Mutex
	calls []string
}

// Increment appends result to the recorded call list.
func (r *recordingCreateCalls) Increment(result string) {
	r.mu.Lock()
	r.calls = append(r.calls, result)
	r.mu.Unlock()
}

// reset returns everything recorded so far and clears the list.
func (r *recordingCreateCalls) reset() []string {
	r.mu.Lock()
	defer r.mu.Unlock()
	recorded := r.calls
	r.calls = nil
	return recorded
}
// recordingCacheGCCalls is a TransportCacheGCCallsMetric that records every
// result string ("deleted"/"skipped") passed to Increment.
type recordingCacheGCCalls struct {
	mu    sync.Mutex
	calls []string
}

// Increment appends result to the recorded call list.
func (r *recordingCacheGCCalls) Increment(result string) {
	r.mu.Lock()
	r.calls = append(r.calls, result)
	r.mu.Unlock()
}

// reset returns everything recorded so far and clears the list.
func (r *recordingCacheGCCalls) reset() []string {
	r.mu.Lock()
	defer r.mu.Unlock()
	recorded := r.calls
	r.calls = nil
	return recorded
}
// recordingCertRotationGCCalls is a TransportCertRotationGCCallsMetric that
// counts Increment calls with an atomic counter (safe to read concurrently
// via r.count.Load()).
type recordingCertRotationGCCalls struct {
	count atomic.Int64
}

func (r *recordingCertRotationGCCalls) Increment() {
	r.count.Add(1)
}
// recordingCacheEntries is a TransportCacheMetric that stores every observed
// cache size so tests can inspect the sequence of observations.
type recordingCacheEntries struct {
	mu     sync.Mutex
	values []int
}

// Observe appends the observed cache size n.
func (r *recordingCacheEntries) Observe(n int) {
	r.mu.Lock()
	r.values = append(r.values, n)
	r.mu.Unlock()
}

// reset returns all recorded observations and clears the list.
func (r *recordingCacheEntries) reset() []int {
	r.mu.Lock()
	defer r.mu.Unlock()
	recorded := r.values
	r.values = nil
	return recorded
}
// installFakeMetrics swaps the global client-go transport metrics for
// in-memory recorders and registers a cleanup that restores the originals
// when the test finishes. It returns the four recorders for inspection.
func installFakeMetrics(t *testing.T) (*recordingCreateCalls, *recordingCacheGCCalls, *recordingCertRotationGCCalls, *recordingCacheEntries) {
	// Save the current implementations and schedule their restoration first,
	// so the globals are put back even if the test fails later.
	prevCreate := metrics.TransportCreateCalls
	prevGC := metrics.TransportCacheGCCalls
	prevRotationGC := metrics.TransportCertRotationGCCalls
	prevEntries := metrics.TransportCacheEntries
	t.Cleanup(func() {
		metrics.TransportCreateCalls = prevCreate
		metrics.TransportCacheGCCalls = prevGC
		metrics.TransportCertRotationGCCalls = prevRotationGC
		metrics.TransportCacheEntries = prevEntries
	})

	createCalls := &recordingCreateCalls{}
	gcCalls := &recordingCacheGCCalls{}
	rotationGCCalls := &recordingCertRotationGCCalls{}
	cacheEntries := &recordingCacheEntries{}
	metrics.TransportCreateCalls = createCalls
	metrics.TransportCacheGCCalls = gcCalls
	metrics.TransportCertRotationGCCalls = rotationGCCalls
	metrics.TransportCacheEntries = cacheEntries
	return createCalls, gcCalls, rotationGCCalls, cacheEntries
}

View File

@@ -389,7 +389,10 @@ func TestNew(t *testing.T) {
}
// We only know how to check TLSConfig on http.Transports
transport := rt.(*http.Transport)
transport, ok := rt.(*http.Transport)
if !ok {
transport = rt.(*trackedTransport).rt.(*http.Transport)
}
switch {
case testCase.TLS && transport.TLSClientConfig == nil:
t.Fatalf("got %#v, expected TLSClientConfig", transport)