From f5fc1e5f67e9f487e6e1b96384c16788e1a9db9a Mon Sep 17 00:00:00 2001 From: Monis Khan Date: Fri, 16 Jan 2026 14:58:52 -0500 Subject: [PATCH] Add GC to client-go TLS cache Signed-off-by: Monis Khan Kubernetes-commit: fa9a1fe5f7084c0a1371d87e880c1a9e58f935a4 --- features/known_features.go | 10 + tools/metrics/metrics.go | 65 +++- transport/ca_rotation_test.go | 2 +- transport/cache.go | 133 ++++++-- transport/cache_test.go | 621 ++++++++++++++++++++++++++++++++-- transport/transport_test.go | 5 +- 6 files changed, 774 insertions(+), 62 deletions(-) diff --git a/features/known_features.go b/features/known_features.go index 6cf49c424..250188dbc 100644 --- a/features/known_features.go +++ b/features/known_features.go @@ -55,6 +55,13 @@ const ( // "application/json" or "application/apply-patch+yaml", respectively. ClientsAllowCBOR Feature = "ClientsAllowCBOR" + // owner: @enj + // beta: v1.36 + // + // If enabled, the client-go TLS transport cache uses weak pointers to allow + // garbage collection of unused transports, preventing unbounded cache growth. + ClientsAllowTLSCacheGC Feature = "ClientsAllowTLSCacheGC" + // owner: @benluddy // kep: https://kep.k8s.io/4222 // alpha: 1.32 @@ -111,6 +118,9 @@ var defaultVersionedKubernetesFeatureGates = map[Feature]VersionedSpecs{ ClientsAllowCBOR: { {Version: version.MustParse("1.32"), Default: false, PreRelease: Alpha}, }, + ClientsAllowTLSCacheGC: { + {Version: version.MustParse("1.36"), Default: true, PreRelease: Beta}, + }, ClientsPreferCBOR: { {Version: version.MustParse("1.32"), Default: false, PreRelease: Alpha}, }, diff --git a/tools/metrics/metrics.go b/tools/metrics/metrics.go index 789c3d0a5..2f626475b 100644 --- a/tools/metrics/metrics.go +++ b/tools/metrics/metrics.go @@ -80,7 +80,7 @@ type TransportCacheMetric interface { } // TransportCreateCallsMetric counts the number of times a transport is created -// partitioned by the result of the cache: hit, miss, uncacheable +// partitioned by the result of the cache: hit, miss, miss-gc, uncacheable type TransportCreateCallsMetric interface { Increment(result string) } @@ -91,6 +91,18 @@ type TransportCAReloadsMetric interface { Increment(result, reason string) } +// TransportCertRotationGCCallsMetric counts the number of times a cert rotation +// goroutine cancel func is called via GC cleanup. +type TransportCertRotationGCCallsMetric interface { + Increment() +} + +// TransportCacheGCCallsMetric counts the number of times a GC cleanup +// attempts to delete a cache entry, partitioned by the result: deleted, skipped. +type TransportCacheGCCallsMetric interface { + Increment(result string) +} + var ( // ClientCertExpiry is the expiry time of a client certificate ClientCertExpiry ExpiryMetric = noopExpiry{} @@ -123,27 +135,34 @@ var ( // TransportCreateCalls is the metric that counts the number of times a new transport // is created TransportCreateCalls TransportCreateCallsMetric = noopTransportCreateCalls{} - // TransportCAReloads is the metric that counts the number of times a CA reload is attempted TransportCAReloads TransportCAReloadsMetric = noopTransportCAReloads{} + // TransportCertRotationGCCalls counts the number of times a cert rotation goroutine + // cancel func is called via GC cleanup + TransportCertRotationGCCalls TransportCertRotationGCCallsMetric = noopTransportCertRotationGCCalls{} + // TransportCacheGCCalls counts the number of times a GC cleanup attempts + // to delete a transport cache entry, partitioned by result: deleted, skipped. + TransportCacheGCCalls TransportCacheGCCallsMetric = noopTransportCacheGCCalls{} ) // RegisterOpts contains all the metrics to register. Metrics may be nil. type RegisterOpts struct { - ClientCertExpiry ExpiryMetric - ClientCertRotationAge DurationMetric - RequestLatency LatencyMetric - ResolverLatency ResolverLatencyMetric - RequestSize SizeMetric - ResponseSize SizeMetric - RateLimiterLatency LatencyMetric - RequestResult ResultMetric - ExecPluginCalls CallsMetric - ExecPluginPolicyCalls PolicyCallsMetric - RequestRetry RetryMetric - TransportCacheEntries TransportCacheMetric - TransportCreateCalls TransportCreateCallsMetric - TransportCAReloads TransportCAReloadsMetric + ClientCertExpiry ExpiryMetric + ClientCertRotationAge DurationMetric + RequestLatency LatencyMetric + ResolverLatency ResolverLatencyMetric + RequestSize SizeMetric + ResponseSize SizeMetric + RateLimiterLatency LatencyMetric + RequestResult ResultMetric + ExecPluginCalls CallsMetric + ExecPluginPolicyCalls PolicyCallsMetric + RequestRetry RetryMetric + TransportCacheEntries TransportCacheMetric + TransportCreateCalls TransportCreateCallsMetric + TransportCAReloads TransportCAReloadsMetric + TransportCertRotationGCCalls TransportCertRotationGCCallsMetric + TransportCacheGCCalls TransportCacheGCCallsMetric } // Register registers metrics for the rest client to use. This can @@ -192,6 +211,12 @@ func Register(opts RegisterOpts) { if opts.TransportCAReloads != nil { TransportCAReloads = opts.TransportCAReloads } + if opts.TransportCertRotationGCCalls != nil { + TransportCertRotationGCCalls = opts.TransportCertRotationGCCalls + } + if opts.TransportCacheGCCalls != nil { + TransportCacheGCCalls = opts.TransportCacheGCCalls + } }) } @@ -243,3 +268,11 @@ func (noopTransportCreateCalls) Increment(string) {} type noopTransportCAReloads struct{} func (noopTransportCAReloads) Increment(result, reason string) {} + +type noopTransportCertRotationGCCalls struct{} + +func (noopTransportCertRotationGCCalls) Increment() {} + +type noopTransportCacheGCCalls struct{} + +func (noopTransportCacheGCCalls) Increment(string) {} diff --git a/transport/ca_rotation_test.go b/transport/ca_rotation_test.go index f4aab8ae2..5bfa64516 100644 --- a/transport/ca_rotation_test.go +++ b/transport/ca_rotation_test.go @@ -252,7 +252,7 @@ func TestCARotationConnectionBehavior(t *testing.T) { if err != nil { t.Fatalf("Failed to create transport: %v", err) } - transport.(*atomicTransportHolder).caRefreshDuration = 500 * time.Millisecond + transport.(*trackedTransport).rt.(*atomicTransportHolder).caRefreshDuration = 500 * time.Millisecond client := &http.Client{ Transport: transport, diff --git a/transport/cache.go b/transport/cache.go index ae1c39ffe..16b62ec9f 100644 --- a/transport/cache.go +++ b/transport/cache.go @@ -21,12 +21,14 @@ import ( "fmt" "net" "net/http" + "runtime" "strings" "sync" "time" + "weak" utilnet "k8s.io/apimachinery/pkg/util/net" - "k8s.io/apimachinery/pkg/util/wait" + clientgofeaturegate "k8s.io/client-go/features" "k8s.io/client-go/tools/metrics" "k8s.io/klog/v2" ) @@ -35,19 +37,20 @@ import ( // same RoundTripper will be returned for configs with identical TLS options If // the config has no custom TLS options, http.DefaultTransport is returned. type tlsTransportCache struct { - mu sync.Mutex - transports map[tlsCacheKey]http.RoundTripper + mu sync.Mutex + transports map[tlsCacheKey]weak.Pointer[trackedTransport] // GC-enabled + strongTransports map[tlsCacheKey]http.RoundTripper // GC-disabled } -// DialerStopCh is stop channel that is passed down to dynamic cert dialer. -// It's exposed as variable for testing purposes to avoid testing for goroutine -// leakages. -var DialerStopCh = wait.NeverStop - const idleConnsPerHost = 25 -var tlsCache = &tlsTransportCache{ - transports: make(map[tlsCacheKey]http.RoundTripper), +var tlsCache = newTLSCache() + +func newTLSCache() *tlsTransportCache { + return &tlsTransportCache{ + transports: make(map[tlsCacheKey]weak.Pointer[trackedTransport]), + strongTransports: make(map[tlsCacheKey]http.RoundTripper), + } } type tlsCacheKey struct { @@ -85,14 +88,18 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { // Ensure we only create a single transport for the given TLS options c.mu.Lock() defer c.mu.Unlock() - defer metrics.TransportCacheEntries.Observe(len(c.transports)) + defer func() { metrics.TransportCacheEntries.Observe(c.lenLocked()) }() // See if we already have a custom transport for this config - if t, ok := c.transports[key]; ok { - metrics.TransportCreateCalls.Increment("hit") - return t, nil + if t, ok := c.getLocked(key); ok { + if t != nil { + metrics.TransportCreateCalls.Increment("hit") + return t, nil + } + metrics.TransportCreateCalls.Increment("miss-gc") + } else { + metrics.TransportCreateCalls.Increment("miss") } - metrics.TransportCreateCalls.Increment("miss") } else { metrics.TransportCreateCalls.Increment("uncacheable") } @@ -119,6 +126,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { // If we use are reloading files, we need to handle certificate rotation properly // TODO(jackkleeman): We can also add rotation here when config.HasCertCallback() is true + var cancel context.CancelFunc if config.TLS.ReloadTLSFiles && tlsConfig != nil && tlsConfig.GetClientCertificate != nil { // The TLS cache is a singleton, so sharing the same name for all of its // background activity seems okay. @@ -126,7 +134,9 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { dynamicCertDialer := certRotatingDialer(logger, tlsConfig.GetClientCertificate, dial) tlsConfig.GetClientCertificate = dynamicCertDialer.GetClientCertificate dial = dynamicCertDialer.connDialer.DialContext - go dynamicCertDialer.run(DialerStopCh) + var ctx context.Context + ctx, cancel = context.WithCancel(context.Background()) + go dynamicCertDialer.run(ctx.Done()) } proxy := http.ProxyFromEnvironment @@ -148,12 +158,95 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { transport = newAtomicTransportHolder(config.TLS.CAFile, config.TLS.CAData, httpTransport) } - if canCache { - // Cache a single transport for these options - c.transports[key] = transport + if !canCache && cancel == nil { + return transport, nil // uncacheable config with no cert rotation - nothing to GC } - return transport, nil + if !clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowTLSCacheGC) { + if canCache { + c.strongTransports[key] = transport + } + return transport, nil // cancel is intentionally discarded and the cert rotation go routine leaks + } + + transportWithGC := &trackedTransport{rt: transport} + + if cancel != nil { + // capture metric as local var so that cleanups do not influence other tests via globals + transportCertRotationGCCalls := metrics.TransportCertRotationGCCalls + runtime.AddCleanup(transportWithGC, func(_ struct{}) { + cancel() + transportCertRotationGCCalls.Increment() + }, struct{}{}) + } + + if canCache { + wp := weak.Make(transportWithGC) + c.transports[key] = wp + // capture metrics as local vars so that cleanups do not influence other tests via globals + transportCacheGCCalls := metrics.TransportCacheGCCalls + transportCacheEntries := metrics.TransportCacheEntries + runtime.AddCleanup(transportWithGC, func(key tlsCacheKey) { + c.mu.Lock() + defer c.mu.Unlock() + + // make sure we only delete the weak pointer created by this specific setLocked call + if c.transports[key] != wp { + transportCacheGCCalls.Increment("skipped") + return + } + delete(c.transports, key) + transportCacheGCCalls.Increment("deleted") + transportCacheEntries.Observe(c.lenLocked()) + }, key) + } + + return transportWithGC, nil +} + +func (c *tlsTransportCache) getLocked(key tlsCacheKey) (http.RoundTripper, bool) { + if !clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowTLSCacheGC) { + v, ok := c.strongTransports[key] + return v, ok + } + + wp, ok := c.transports[key] + if !ok { + return nil, false + } + + v := wp.Value() + + if v == nil { // avoid typed nil + return nil, true // key exists but value has been garbage collected + } + + return v, true +} + +func (c *tlsTransportCache) lenLocked() int { + if !clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowTLSCacheGC) { + return len(c.strongTransports) + } + return len(c.transports) +} + +// trackedTransport wraps an http.RoundTripper to serve as the weak.Pointer +// target in the TLS transport cache. Dropping all references to this object +// triggers GC cleanup of the cache entry and any cert rotation goroutine. +type trackedTransport struct { + rt http.RoundTripper +} + +var _ http.RoundTripper = &trackedTransport{} +var _ utilnet.RoundTripperWrapper = &trackedTransport{} + +func (v *trackedTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return v.rt.RoundTrip(req) +} + +func (v *trackedTransport) WrappedRoundTripper() http.RoundTripper { + return v.rt } // tlsConfigKey returns a unique key for tls.Config objects returned from TLSConfigFor diff --git a/transport/cache_test.go b/transport/cache_test.go index e7cb74019..d2e565f58 100644 --- a/transport/cache_test.go +++ b/transport/cache_test.go @@ -25,11 +25,17 @@ import ( "net/url" "os" "reflect" + "runtime" + "sync" + "sync/atomic" "testing" "time" + "weak" + "k8s.io/apimachinery/pkg/util/wait" clientgofeaturegate "k8s.io/client-go/features" clientfeaturestesting "k8s.io/client-go/features/testing" + "k8s.io/client-go/tools/metrics" ) func TestTLSConfigKey(t *testing.T) { @@ -237,7 +243,6 @@ func TestTLSConfigKeyCARotationDisabled(t *testing.T) { // TestTLSTransportCacheCARotation tests transport cache behavior with CA rotation func TestTLSTransportCacheCARotation(t *testing.T) { - clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true) caFile := writeCAFile(t, []byte(testCACert1)) testCases := []struct { @@ -289,10 +294,8 @@ func TestTLSTransportCacheCARotation(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // Create new cache for testing - tlsCaches := &tlsTransportCache{ - transports: make(map[tlsCacheKey]http.RoundTripper), - } + createCalls, _, _, _ := installFakeMetrics(t) + tlsCaches := newTLSCache() rt, err := tlsCaches.get(tc.config) if err != nil { @@ -304,10 +307,33 @@ func TestTLSTransportCacheCARotation(t *testing.T) { if rt != http.DefaultTransport { t.Errorf("Expected default transport, got %T", rt) } + // canCache=true so getLocked fires "miss", but DefaultTransport + // is returned before setLocked — the entry is never stored. + if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "miss" { + t.Errorf("expected [miss], got %v", calls) + } + + // A second call is also "miss" — nothing was stored. + rt2, err := tlsCaches.get(tc.config) + if err != nil { + t.Fatal(err) + } + if rt2 != http.DefaultTransport { + t.Errorf("second call: expected default transport, got %T", rt2) + } + if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "miss" { + t.Errorf("expected [miss] on second call (nothing stored), got %v", calls) + } + + requireCacheLen(t, tlsCaches, 0) return } - if tc.expectWrapper { + if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "miss" { + t.Errorf("expected [miss] on first get, got %v", calls) + } + + if rt := rt.(*trackedTransport).rt; tc.expectWrapper { // Should be wrapped in atomicTransportHolder if _, ok := rt.(*atomicTransportHolder); !ok { t.Errorf("Expected atomicTransportHolder, got %T", rt) @@ -327,24 +353,14 @@ func TestTLSTransportCacheCARotation(t *testing.T) { if err != nil { t.Fatalf("Unexpected error on second call: %v", err) } - if rt != rt2 { t.Error("Expected same transport instance from cache") } - - // Verify cache size - tlsCaches.mu.Lock() - cacheSize := len(tlsCaches.transports) - tlsCaches.mu.Unlock() - - expectedCacheSize := 1 - if !tc.expectCacheable { - expectedCacheSize = 0 + if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "hit" { + t.Errorf("expected [hit] on second get, got %v", calls) } - if cacheSize != expectedCacheSize { - t.Errorf("Expected %d transports in cache, got %d", expectedCacheSize, cacheSize) - } + requireCacheLen(t, tlsCaches, 1) }) } } @@ -353,13 +369,14 @@ func TestTLSTransportCacheCARotationDisabled(t *testing.T) { clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false) caFile := writeCAFile(t, []byte(testCACert1)) - cache := &tlsTransportCache{transports: make(map[tlsCacheKey]http.RoundTripper)} + cache := newTLSCache() rt, err := cache.get(&Config{TLS: TLSConfig{CAFile: caFile}}) if err != nil { t.Fatalf("Unexpected error: %v", err) } + rt = rt.(*trackedTransport).rt if _, ok := rt.(*atomicTransportHolder); ok { t.Error("Expected plain *http.Transport when feature gate is disabled, got atomicTransportHolder") } @@ -381,9 +398,7 @@ func TestEmptyCAFileRotationLifecycle(t *testing.T) { }, } - tlsCaches := &tlsTransportCache{ - transports: make(map[tlsCacheKey]http.RoundTripper), - } + tlsCaches := newTLSCache() rt, err := tlsCaches.get(config) if err != nil { @@ -391,7 +406,7 @@ func TestEmptyCAFileRotationLifecycle(t *testing.T) { } // Verify newAtomicTransportHolder is successfully generated - holder, ok := rt.(*atomicTransportHolder) + holder, ok := rt.(*trackedTransport).rt.(*atomicTransportHolder) if !ok { t.Fatalf("Expected atomicTransportHolder, got %T", rt) } @@ -424,3 +439,561 @@ func TestEmptyCAFileRotationLifecycle(t *testing.T) { t.Fatal("Expected RootCAs to be populated after writing valid cert data and refreshing") } } + +// TestCacheHoldAfterCARotation verifies that holding the *atomicTransportHolder +// keeps the cache entry alive even after CA rotation swaps the inner transport. +func TestCacheHoldAfterCARotation(t *testing.T) { + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true) + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true) + _, gcCalls, _, _ := installFakeMetrics(t) + + caFile := writeCAFile(t, []byte(testCACert1)) + + cache := newTLSCache() + + rt, err := cache.get(&Config{TLS: TLSConfig{ServerName: "reload-test", CAFile: caFile}}) + if err != nil { + t.Fatal(err) + } + requireCacheLen(t, cache, 1) + + holder := rt.(*trackedTransport).rt.(*atomicTransportHolder) + + originalInner := holder.getTransport(context.Background()) + if originalInner == nil { + t.Fatal("expected non-nil transport") + } + + // Simulate CA rotation. + if err := os.WriteFile(caFile, []byte(testCACert2), 0644); err != nil { + t.Fatal(err) + } + holder.mu.Lock() + holder.transportLastChecked = time.Now().Add(-time.Hour) + holder.mu.Unlock() + + newInner := holder.getTransport(context.Background()) + if newInner == nil { + t.Fatal("expected non-nil transport after rotation") + } + if newInner == originalInner { + t.Fatal("expected transport to change after CA rotation") + } + + // Cache entry must survive because the holder is alive. + for range 3 { + runtime.GC() + } + requireCacheLen(t, cache, 1) + runtime.KeepAlive(rt) + + pollCacheSizeWithGC(t, cache, 0) + if calls := gcCalls.reset(); len(calls) != 1 || calls[0] != "deleted" { + t.Errorf("expected [deleted] after eviction, got %v", calls) + } +} + +// TestCacheGCDisabledNoEviction verifies that with the GC feature gate disabled, +// cache entries are stored in strongTransports and never evicted. +func TestCacheGCDisabledNoEviction(t *testing.T) { + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, false) + createCalls, gcCalls, _, cacheEntries := installFakeMetrics(t) + + cache := newTLSCache() + + rt, err := cache.get(&Config{TLS: TLSConfig{ServerName: "gc-disabled-test"}}) + if err != nil { + t.Fatal(err) + } + + if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "miss" { + t.Errorf("expected [miss], got %v", calls) + } + + if _, ok := rt.(*trackedTransport); ok { + t.Error("expected plain transport, not *trackedTransport, when GC is disabled") + } + + rt2, err := cache.get(&Config{TLS: TLSConfig{ServerName: "gc-disabled-test"}}) + if err != nil { + t.Fatal(err) + } + if rt != rt2 { + t.Error("expected cache hit to return same transport") + } + if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "hit" { + t.Errorf("expected [hit], got %v", calls) + } + + requireCacheLen(t, cache, 1) + + runtime.KeepAlive(rt) + for range 10 { + runtime.GC() + } + requireCacheLen(t, cache, 1) + + if calls := gcCalls.reset(); len(calls) != 0 { + t.Errorf("expected no GC cache calls when GC is disabled, got %v", calls) + } + + // Both get() calls observe 1 — the defer evaluates lenLocked() after setLocked stores. + if values := cacheEntries.reset(); len(values) != 2 || values[0] != 1 || values[1] != 1 { + t.Errorf("expected cache entries observations [1 1], got %v", values) + } +} + +// TestCacheReviveAfterDrop verifies that when rt1 is still alive (via KeepAlive), +// a second get() for the same key is a deterministic cache hit. +func TestCacheReviveAfterDrop(t *testing.T) { + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true) + createCalls, _, _, _ := installFakeMetrics(t) + + cache := newTLSCache() + + config := &Config{TLS: TLSConfig{ServerName: "keep-alive-test"}} + + rt1, err := cache.get(config) + if err != nil { + t.Fatal(err) + } + requireCacheLen(t, cache, 1) + createCalls.reset() + + // rt1 is alive here, so the weak pointer must still resolve. + rt2, err := cache.get(config) + if err != nil { + t.Fatal(err) + } + // rt1 and rt2 are the same object (cache hit). + if rt1 != rt2 { + t.Error("expected same transport (cache hit) since rt1 is still alive") + } + if calls := createCalls.reset(); len(calls) != 1 || calls[0] != "hit" { + t.Errorf("expected [hit], got %v", calls) + } + requireCacheLen(t, cache, 1) + + runtime.KeepAlive(rt1) + pollCacheSizeWithGC(t, cache, 0) +} + +// TestUncacheableCertRotationLeak verifies that the cert rotation goroutine +// is stopped when an uncacheable transport is garbage collected. +func TestUncacheableCertRotationLeak(t *testing.T) { + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true) + _, _, rotationGCCalls, _ := installFakeMetrics(t) + + certFile := writeCAFile(t, []byte(certData)) + keyFile := writeCAFile(t, []byte(keyData)) + + cache := newTLSCache() + + baseline := runtime.NumGoroutine() + + rt, err := cache.get(&Config{ + TLS: TLSConfig{ + CertFile: certFile, + KeyFile: keyFile, + }, + Proxy: func(*http.Request) (*url.URL, error) { return nil, nil }, // force uncacheable + }) + if err != nil { + t.Fatal(err) + } + + requireCacheLen(t, cache, 0) + + afterCreate := runtime.NumGoroutine() + if afterCreate <= baseline { + t.Fatalf("expected goroutine count to increase after creating transport with cert rotation, got baseline=%d after=%d", baseline, afterCreate) + } + + runtime.KeepAlive(rt) + + err = wait.PollUntilContextTimeout(t.Context(), 10*time.Millisecond, 10*time.Second, true, func(_ context.Context) (bool, error) { + runtime.GC() + return runtime.NumGoroutine() <= baseline, nil + }) + if err != nil { + t.Errorf("goroutine leak: cert rotation goroutine was not stopped for uncacheable transport (baseline=%d current=%d)", baseline, runtime.NumGoroutine()) + } + + if n := rotationGCCalls.count.Load(); n != 1 { + t.Errorf("expected TransportCertRotationGCCalls=1, got %d", n) + } +} + +// TestCacheableCertRotationLeak verifies that the cert rotation goroutine +// is stopped when a cacheable transport with cert rotation is garbage collected. +// This exercises the canCache=true && cancel != nil path through setLocked. +func TestCacheableCertRotationLeak(t *testing.T) { + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true) + _, gcCalls, rotationGCCalls, _ := installFakeMetrics(t) + + certFile := writeCAFile(t, []byte(certData)) + keyFile := writeCAFile(t, []byte(keyData)) + + cache := newTLSCache() + + baseline := runtime.NumGoroutine() + + // CertFile+KeyFile triggers ReloadTLSFiles and the cert rotation goroutine. + // No Proxy means canCache=true. + rt, err := cache.get(&Config{ + TLS: TLSConfig{ + CertFile: certFile, + KeyFile: keyFile, + ServerName: "cacheable-cert-rotation-test", + }, + }) + if err != nil { + t.Fatal(err) + } + + requireCacheLen(t, cache, 1) + + afterCreate := runtime.NumGoroutine() + if afterCreate <= baseline { + t.Fatalf("expected goroutine count to increase after creating transport with cert rotation, got baseline=%d after=%d", baseline, afterCreate) + } + + runtime.KeepAlive(rt) + + gcCalls.reset() + pollCacheSizeWithGC(t, cache, 0) + + err = wait.PollUntilContextTimeout(t.Context(), 10*time.Millisecond, 10*time.Second, true, func(_ context.Context) (bool, error) { + runtime.GC() + return runtime.NumGoroutine() <= baseline, nil + }) + if err != nil { + t.Errorf("goroutine leak: cert rotation goroutine was not stopped for cacheable transport (baseline=%d current=%d)", baseline, runtime.NumGoroutine()) + } + + if n := rotationGCCalls.count.Load(); n != 1 { + t.Errorf("expected TransportCertRotationGCCalls=1, got %d", n) + } + if calls := gcCalls.reset(); len(calls) != 1 || calls[0] != "deleted" { + t.Errorf("expected [deleted], got %v", calls) + } +} + +// TestCacheGCDisabledCertRotationNoCancel verifies that with the GC feature gate +// disabled, the cert rotation goroutine is NOT stopped (pre-GC behavior). +func TestCacheGCDisabledCertRotationNoCancel(t *testing.T) { + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, false) + _, _, rotationGCCalls, _ := installFakeMetrics(t) + + certFile := writeCAFile(t, []byte(certData)) + keyFile := writeCAFile(t, []byte(keyData)) + + cache := newTLSCache() + + baseline := runtime.NumGoroutine() + + rt, err := cache.get(&Config{ + TLS: TLSConfig{ + CertFile: certFile, + KeyFile: keyFile, + ServerName: "gc-disabled-cert-rotation", + }, + }) + if err != nil { + t.Fatal(err) + } + + afterCreate := runtime.NumGoroutine() + if afterCreate <= baseline { + t.Fatalf("expected cert rotation goroutine to start, got baseline=%d after=%d", baseline, afterCreate) + } + + // Drop rt after this point + // if GC cleanup were enabled, cancel would fire and the goroutine would stop. + runtime.KeepAlive(rt) + + for range 10 { + runtime.GC() + runtime.Gosched() + } + time.Sleep(50 * time.Millisecond) // let any potential cleanup callbacks run + + if n := rotationGCCalls.count.Load(); n != 0 { + t.Errorf("expected TransportCertRotationGCCalls=0 when GC disabled, got %d", n) + } + + // Goroutine must still be running — cancel was intentionally discarded. + if runtime.NumGoroutine() <= baseline { + t.Error("cert rotation goroutine was stopped — it should run indefinitely when GC is disabled") + } +} + +// TestCacheStaleEvictionSkipped verifies that when a GC cleanup fires for an +// old entry after the map slot has been replaced, the deletion is skipped. +// We force this by replacing the map entry while the old trackedTransport +// is still alive, then dropping it so its cleanup fires against the new entry. +func TestCacheStaleEvictionSkipped(t *testing.T) { + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true) + _, gcCalls, _, _ := installFakeMetrics(t) + + cache := newTLSCache() + + cfg := &Config{TLS: TLSConfig{ServerName: "stale-test"}} + rt1, err := cache.get(cfg) + if err != nil { + t.Fatal(err) + } + requireCacheLen(t, cache, 1) + + // Replace the map entry while rt1 is still alive. This simulates a + // concurrent get() that created a new entry for the same key. + key, _, _ := tlsConfigKey(cfg) + cache.mu.Lock() + replacement := &trackedTransport{rt: &http.Transport{}} + cache.transports[key] = weak.Make(replacement) + cache.mu.Unlock() + + // Drop rt1 after this point. Its cleanup fires and sees c.transports[key] != wp -> "skipped". + runtime.KeepAlive(rt1) + gcCalls.reset() + + err = wait.PollUntilContextTimeout(t.Context(), 10*time.Millisecond, 10*time.Second, true, func(_ context.Context) (bool, error) { + runtime.GC() + gcCalls.mu.Lock() + n := len(gcCalls.calls) + gcCalls.mu.Unlock() + return n > 0, nil + }) + if err != nil { + t.Fatal("timed out waiting for stale cleanup to fire") + } + + calls := gcCalls.reset() + if len(calls) != 1 || calls[0] != "skipped" { + t.Errorf("expected [skipped], got %v", calls) + } + + // The replacement entry must still be in the map because of the keep alive + requireCacheLen(t, cache, 1) + runtime.KeepAlive(replacement) +} + +// TestCacheLeak does a black box test of TLS cache GC from New. +// It uses a large number of cache entries generated in parallel. +func TestCacheLeak(t *testing.T) { + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowTLSCacheGC, true) + _, gcCalls, _, cacheEntries := installFakeMetrics(t) + + pollCacheSizeWithGC(t, tlsCache, 0) // clean start with global cache + + rt1, err := New(&Config{TLS: TLSConfig{ServerName: "1"}}) + if err != nil { + t.Fatal(err) + } + rt2, err := New(&Config{TLS: TLSConfig{ServerName: "2"}}) + if err != nil { + t.Fatal(err) + } + rt3, err := New(&Config{TLS: TLSConfig{ServerName: "1"}}) + if err != nil { + t.Fatal(err) + } + + requireCacheLen(t, tlsCache, 2) // rt1 and rt2 (rt3 is the same as rt1) + + var wg wait.Group + var d net.Dialer + var rts []http.RoundTripper + var rtsLock sync.Mutex + for i := range 1_000 { // outer loop forces cache miss via dialer + dh := &DialHolder{Dial: d.DialContext} + for range i%7 + 1 { // inner loop exercises each cache value having 1 to N references + wg.Start(func() { + rt, err := New(&Config{DialHolder: dh}) + if err != nil { + panic(err) + } + rtsLock.Lock() + rts = append(rts, rt) // keep a live reference to the round tripper + rtsLock.Unlock() + }) + } + } + wg.Wait() + + requireCacheLen(t, tlsCache, 1_000+2) // rts and rt1 and rt2 (rt3 is the same as rt1) + + runtime.KeepAlive(rts) // prevent round trippers from being GC'd too early + + // Reset before the eviction we want to measure. + gcCalls.reset() + cacheEntries.reset() + + pollCacheSizeWithGC(t, tlsCache, 2) // rt1 and rt2 (rt3 is the same as rt1) + + calls := gcCalls.reset() + deletedCount := 0 + for _, c := range calls { + if c == "deleted" { + deletedCount++ + } else { + t.Errorf("expected deleted call, got %s", c) + } + } + if deletedCount != 1_000 { + t.Errorf("expected 1000 deleted calls, got %d (total calls: %d)", deletedCount, len(calls)) + } + + // Each "deleted" cleanup calls Observe(lenLocked()). The last observation + // should be 2 (rt1 and rt2 remain). + evictionObservations := cacheEntries.reset() + if len(evictionObservations) != 1000 { + t.Errorf("expected exactly 1000 cache entries observations from GC cleanup, got %d", len(evictionObservations)) + } else if last := evictionObservations[len(evictionObservations)-1]; last != 2 { + t.Errorf("expected last cache entries observation to be 2 (rt1+rt2 remain), got %d", last) + } + + runtime.KeepAlive(rt1) + runtime.KeepAlive(rt2) + runtime.KeepAlive(rt3) + + pollCacheSizeWithGC(t, tlsCache, 0) + + calls = gcCalls.reset() + if len(calls) != 2 || calls[0] != "deleted" || calls[1] != "deleted" { + t.Errorf("expected 2 deleted calls for rt1/rt2, got %v", calls) + } +} + +func requireCacheLen(t *testing.T, c *tlsTransportCache, want int) { + t.Helper() + + if cacheLen(t, c) != want { + t.Fatalf("cache len %d, want %d", cacheLen(t, c), want) + } +} + +func cacheLen(t *testing.T, c *tlsTransportCache) int { + t.Helper() + + c.mu.Lock() + defer c.mu.Unlock() + + if !clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowTLSCacheGC) { + if len(c.transports) != 0 { + t.Fatalf("transports len %d must be 0", len(c.transports)) + } + return len(c.strongTransports) + } + + if len(c.strongTransports) != 0 { + t.Fatalf("strongTransports len %d must be 0", len(c.strongTransports)) + } + return len(c.transports) +} + +func pollCacheSizeWithGC(t *testing.T, c *tlsTransportCache, want int) { + t.Helper() + + if err := wait.PollUntilContextTimeout(t.Context(), 10*time.Millisecond, 10*time.Second, true, func(_ context.Context) (done bool, _ error) { + runtime.GC() // run the garbage collector so the cleanups run + return cacheLen(t, c) == want, nil + }); err != nil { + t.Fatalf("cache len %d, want %d: %v", cacheLen(t, c), want, err) + } + + // make sure the cache size is stable even when more GC's happen + // three times should be enough to make the test flake if the implementation is buggy + for range 3 { + runtime.GC() + } + requireCacheLen(t, c, want) +} + +type recordingCreateCalls struct { + mu sync.Mutex + calls []string +} + +func (r *recordingCreateCalls) Increment(result string) { + r.mu.Lock() + defer r.mu.Unlock() + r.calls = append(r.calls, result) +} + +func (r *recordingCreateCalls) reset() []string { + r.mu.Lock() + defer r.mu.Unlock() + c := r.calls + r.calls = nil + return c +} + +type recordingCacheGCCalls struct { + mu sync.Mutex + calls []string +} + +func (r *recordingCacheGCCalls) Increment(result string) { + r.mu.Lock() + defer r.mu.Unlock() + r.calls = append(r.calls, result) +} + +func (r *recordingCacheGCCalls) reset() []string { + r.mu.Lock() + defer r.mu.Unlock() + c := r.calls + r.calls = nil + return c +} + +type recordingCertRotationGCCalls struct { + count atomic.Int64 +} + +func (r *recordingCertRotationGCCalls) Increment() { + r.count.Add(1) +} + +type recordingCacheEntries struct { + mu sync.Mutex + values []int +} + +func (r *recordingCacheEntries) Observe(n int) { + r.mu.Lock() + defer r.mu.Unlock() + r.values = append(r.values, n) +} + +func (r *recordingCacheEntries) reset() []int { + r.mu.Lock() + defer r.mu.Unlock() + v := r.values + r.values = nil + return v +} + +func installFakeMetrics(t *testing.T) (*recordingCreateCalls, *recordingCacheGCCalls, *recordingCertRotationGCCalls, *recordingCacheEntries) { + createCalls := &recordingCreateCalls{} + gcCalls := &recordingCacheGCCalls{} + rotationGCCalls := &recordingCertRotationGCCalls{} + cacheEntries := &recordingCacheEntries{} + + origCreate := metrics.TransportCreateCalls + origGC := metrics.TransportCacheGCCalls + origRotationGC := metrics.TransportCertRotationGCCalls + origEntries := metrics.TransportCacheEntries + metrics.TransportCreateCalls = createCalls + metrics.TransportCacheGCCalls = gcCalls + metrics.TransportCertRotationGCCalls = rotationGCCalls + metrics.TransportCacheEntries = cacheEntries + t.Cleanup(func() { + metrics.TransportCreateCalls = origCreate + metrics.TransportCacheGCCalls = origGC + metrics.TransportCertRotationGCCalls = origRotationGC + metrics.TransportCacheEntries = origEntries + }) + return createCalls, gcCalls, rotationGCCalls, cacheEntries +} diff --git a/transport/transport_test.go b/transport/transport_test.go index d732ebdef..27cbb5475 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -389,7 +389,10 @@ func TestNew(t *testing.T) { } // We only know how to check TLSConfig on http.Transports - transport := rt.(*http.Transport) + transport, ok := rt.(*http.Transport) + if !ok { + transport = rt.(*trackedTransport).rt.(*http.Transport) + } switch { case testCase.TLS && transport.TLSClientConfig == nil: t.Fatalf("got %#v, expected TLSClientConfig", transport)