diff --git a/features/known_features.go b/features/known_features.go index bb4132c99..6cf49c424 100644 --- a/features/known_features.go +++ b/features/known_features.go @@ -38,6 +38,13 @@ const ( // events for items popped off the FIFO. AtomicFIFO Feature = "AtomicFIFO" + // owner: @yt2985 + // beta: 1.36 + // + // If enabled, allows clients to gracefully handle Certificate Authority (CA) + // rotations without dropping connections or requiring a restart. + ClientsAllowCARotation Feature = "ClientsAllowCARotation" + // owner: @benluddy // kep: https://kep.k8s.io/4222 // alpha: 1.32 @@ -98,6 +105,9 @@ var defaultVersionedKubernetesFeatureGates = map[Feature]VersionedSpecs{ AtomicFIFO: { {Version: version.MustParse("1.36"), Default: true, PreRelease: Beta}, }, + ClientsAllowCARotation: { + {Version: version.MustParse("1.36"), Default: true, PreRelease: Beta}, + }, ClientsAllowCBOR: { {Version: version.MustParse("1.32"), Default: false, PreRelease: Alpha}, }, diff --git a/tools/metrics/metrics.go b/tools/metrics/metrics.go index e364b7e1c..789c3d0a5 100644 --- a/tools/metrics/metrics.go +++ b/tools/metrics/metrics.go @@ -85,6 +85,12 @@ type TransportCreateCallsMetric interface { Increment(result string) } +// TransportCAReloadsMetric counts the number of times a CA reload is attempted, +// partitioned by the result and reason. +type TransportCAReloadsMetric interface { + Increment(result, reason string) +} + var ( // ClientCertExpiry is the expiry time of a client certificate ClientCertExpiry ExpiryMetric = noopExpiry{} @@ -117,6 +123,9 @@ var ( // TransportCreateCalls is the metric that counts the number of times a new transport // is created TransportCreateCalls TransportCreateCallsMetric = noopTransportCreateCalls{} + + // TransportCAReloads is the metric that counts the number of times a CA reload is attempted + TransportCAReloads TransportCAReloadsMetric = noopTransportCAReloads{} ) // RegisterOpts contains all the metrics to register. Metrics may be nil. 
@@ -134,6 +143,7 @@ type RegisterOpts struct {
 	RequestRetry          RetryMetric
 	TransportCacheEntries TransportCacheMetric
 	TransportCreateCalls  TransportCreateCallsMetric
+	TransportCAReloads    TransportCAReloadsMetric
 }
 
 // Register registers metrics for the rest client to use. This can
@@ -179,6 +189,9 @@ func Register(opts RegisterOpts) {
 		if opts.TransportCreateCalls != nil {
 			TransportCreateCalls = opts.TransportCreateCalls
 		}
+		if opts.TransportCAReloads != nil {
+			TransportCAReloads = opts.TransportCAReloads
+		}
 	})
 }
 
@@ -226,3 +239,7 @@ func (noopTransportCache) Observe(int) {}
 type noopTransportCreateCalls struct{}
 
 func (noopTransportCreateCalls) Increment(string) {}
+
+type noopTransportCAReloads struct{}
+
+func (noopTransportCAReloads) Increment(result, reason string) {}
diff --git a/transport/ca_rotation.go b/transport/ca_rotation.go
new file mode 100644
index 000000000..b802eedf4
--- /dev/null
+++ b/transport/ca_rotation.go
@@ -0,0 +1,176 @@
+/*
+Copyright The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package transport
+
+import (
+	"bytes"
+	"context"
+	"net/http"
+	"os"
+	"sync"
+	"time"
+
+	utilnet "k8s.io/apimachinery/pkg/util/net"
+	"k8s.io/client-go/tools/metrics"
+	"k8s.io/klog/v2"
+	"k8s.io/utils/clock"
+)
+
+var _ utilnet.RoundTripperWrapper = &atomicTransportHolder{}
+
+// atomicTransportHolder is an http.RoundTripper that serves requests through an
+// *http.Transport and periodically re-reads caFile, swapping in a cloned
+// transport with a fresh root CA pool when the file content changes. This
+// enables graceful CA rotation without cache complexity.
+type atomicTransportHolder struct {
+	// caFile is the path that is re-read to detect CA changes.
+	caFile string
+	// clock and caRefreshDuration are used to allow for testing time-based logic.
+	clock             clock.Clock
+	caRefreshDuration time.Duration
+	// mu covers transport, transportLastChecked and currentCAData.
+	mu                   sync.RWMutex
+	transport            *http.Transport
+	transportLastChecked time.Time
+	// currentCAData is the CA data the current transport is expected to have been
+	// built from; a reload is skipped while the file content still equals it.
+	currentCAData []byte
+}
+
+// RoundTrip implements http.RoundTripper. It sends req through the current
+// transport, first re-checking the CA file if the last check is at least
+// caRefreshDuration old.
+func (h *atomicTransportHolder) RoundTrip(req *http.Request) (*http.Response, error) {
+	return h.getTransport(req.Context()).RoundTrip(req)
+}
+
+// WrappedRoundTripper implements utilnet.RoundTripperWrapper by returning the
+// transport currently in use. It never triggers a CA reload.
+func (h *atomicTransportHolder) WrappedRoundTripper() http.RoundTripper {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+
+	return h.transport
+}
+
+// getTransport returns the transport to use for a request. The fast path only
+// takes the read lock; once caRefreshDuration has elapsed since the last check
+// it takes the write lock and attempts a reload before returning.
+func (h *atomicTransportHolder) getTransport(ctx context.Context) *http.Transport {
+	if rt := h.getTransportIfFresh(); rt != nil {
+		return rt
+	}
+
+	h.mu.Lock()
+	defer h.mu.Unlock()
+
+	h.tryRefreshTransportLocked(ctx)
+	return h.transport
+}
+
+// getTransportIfFresh returns the current transport if the CA file was checked
+// less than caRefreshDuration ago, and nil otherwise.
+func (h *atomicTransportHolder) getTransportIfFresh() *http.Transport {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+
+	if h.clock.Since(h.transportLastChecked) < h.caRefreshDuration {
+		return h.transport
+	}
+	return nil
+}
+
+// tryRefreshTransportLocked re-reads caFile and, if its content differs from
+// currentCAData and rootCertPool accepts it, replaces h.transport with a clone
+// that uses the new roots. On a read or parse error, or if the file is empty or
+// unchanged, the existing transport is kept. Every attempt is recorded in
+// metrics.TransportCAReloads. The caller must hold h.mu for writing.
+func (h *atomicTransportHolder) tryRefreshTransportLocked(ctx context.Context) {
+	// If some other goroutine already checked/updated the CA
+	if h.clock.Since(h.transportLastChecked) < h.caRefreshDuration {
+		return
+	}
+
+	// only attempt CA reload once per caRefreshDuration, even if the reload fails
+	h.transportLastChecked = h.clock.Now()
+
+	logger := klog.FromContext(ctx).WithValues("caFile", h.caFile)
+
+	logger.V(4).Info("Checking CA file content")
+
+	// Load new CA data from file
+	newCAData, err := os.ReadFile(h.caFile)
+	// Return old transport on read error
+	if err != nil {
+		logger.Error(err, "Failed to read CA data from file")
+		metrics.TransportCAReloads.Increment("failure", "read_error")
+		return
+	}
+
+	if len(newCAData) == 0 {
+		logger.Info("CA file empty, skipping transport rotation")
+		metrics.TransportCAReloads.Increment("failure", "empty")
+		return
+	}
+
+	if bytes.Equal(h.currentCAData, newCAData) {
+		logger.V(4).Info("CA file unchanged, skipping transport rotation")
+		metrics.TransportCAReloads.Increment("success", "unchanged")
+		return
+	}
+
+	logger.V(4).Info("CA content changed, updating transport")
+
+	// Load new CA pool
+	newCAs, err := rootCertPool(newCAData)
+	// Return old transport on parse error
+	if err != nil {
+		logger.Error(err, "Failed to parse CA data from file")
+		metrics.TransportCAReloads.Increment("failure", "ca_parse_error")
+		return
+	}
+	newTransport := h.transport.Clone()
+	newTransport.TLSClientConfig.RootCAs = newCAs
+	oldTransport := h.transport
+	h.transport = newTransport
+	// Update our tracking of current CA data
+	h.currentCAData = newCAData
+
+	// Close idle connections on the old transport to encourage migration
+	oldTransport.CloseIdleConnections()
+
+	logger.V(4).Info("Transport updated for CA rotation")
+	metrics.TransportCAReloads.Increment("success", "updated")
+}
+
+// newAtomicTransportHolder creates a new holder for CA file reloading scenarios.
+// The caFile must be specified.
+// caData may be empty but should correspond to the contents of caFile; it is
+// copied, so later writes to the caller's slice cannot skew change detection.
+// transport must have a TLS config and its root CAs should match caData.
+// The CA file is first re-checked one refresh interval (5 minutes) after creation.
+func newAtomicTransportHolder(caFile string, caData []byte, transport *http.Transport) *atomicTransportHolder {
+	c := clock.RealClock{}
+	return &atomicTransportHolder{
+		caFile:               caFile,
+		currentCAData:        bytes.Clone(caData),
+		clock:                c,
+		caRefreshDuration:    5 * time.Minute,
+		transport:            transport,
+		transportLastChecked: c.Now(),
+	}
+}
diff --git a/transport/ca_rotation_test.go b/transport/ca_rotation_test.go
new file mode 100644
index 000000000..f4aab8ae2
--- /dev/null
+++ b/transport/ca_rotation_test.go
@@ -0,0 +1,560 @@
+/*
+Copyright The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package transport + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "k8s.io/apimachinery/pkg/util/wait" + clientgofeaturegate "k8s.io/client-go/features" + clientfeaturestesting "k8s.io/client-go/features/testing" + "k8s.io/client-go/tools/metrics" + "k8s.io/client-go/util/cert" + testingclock "k8s.io/utils/clock/testing" +) + +const ( + // Use the same rootCACert as transport_test.go + testCACert1 = `-----BEGIN CERTIFICATE----- +MIIC4DCCAcqgAwIBAgIBATALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu +MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIxNTczN1oXDTE2MDExNTIxNTcz +OFowIzEhMB8GA1UEAwwYMTAuMTMuMTI5LjEwNkAxNDIxMzU5MDU4MIIBIjANBgkq +hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAunDRXGwsiYWGFDlWH6kjGun+PshDGeZX +xtx9lUnL8pIRWH3wX6f13PO9sktaOWW0T0mlo6k2bMlSLlSZgG9H6og0W6gLS3vq +s4VavZ6DbXIwemZG2vbRwsvR+t4G6Nbwelm6F8RFnA1Fwt428pavmNQ/wgYzo+T1 +1eS+HiN4ACnSoDSx3QRWcgBkB1g6VReofVjx63i0J+w8Q/41L9GUuLqquFxu6ZnH +60vTB55lHgFiDLjA1FkEz2dGvGh/wtnFlRvjaPC54JH2K1mPYAUXTreoeJtLJKX0 +ycoiyB24+zGCniUmgIsmQWRPaOPircexCp1BOeze82BT1LCZNTVaxQIDAQABoyMw +ITAOBgNVHQ8BAf8EBAMCAKQwDwYDVR0TAQH/BAUwAwEB/zALBgkqhkiG9w0BAQsD +ggEBADMxsUuAFlsYDpF4fRCzXXwrhbtj4oQwcHpbu+rnOPHCZupiafzZpDu+rw4x +YGPnCb594bRTQn4pAu3Ac18NbLD5pV3uioAkv8oPkgr8aUhXqiv7KdDiaWm6sbAL +EHiXVBBAFvQws10HMqMoKtO8f1XDNAUkWduakR/U6yMgvOPwS7xl0eUTqyRB6zGb +K55q2dejiFWaFqB/y78txzvz6UlOZKE44g2JAVoJVM6kGaxh33q8/FmrL4kuN3ut +W+MmJCVDvd4eEqPwbp7146ZWTqpIJ8lvA6wuChtqV8lhAPka2hD/LMqY8iXNmfXD +uml0obOEy+ON91k+SWTJ3ggmF/U= +-----END CERTIFICATE-----` + + // A different CA cert for testing rotation (modified version of certData from transport_test.go) + testCACert2 = `-----BEGIN CERTIFICATE----- +MIIC6jCCAdSgAwIBAgIBCzALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu +MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIyMDEzMVoXDTE2MDExNTIyMDEz +MlowGzEZMBcGA1UEAxMQb3BlbnNoaWZ0LWNsaWVudDCCASIwDQYJKoZIhvcNAQEB +BQADggEPADCCAQoCggEBAKtdhz0+uCLXw5cSYns9rU/XifFSpb/x24WDdrm72S/v 
+b9BPYsAStiP148buylr1SOuNi8sTAZmlVDDIpIVwMLff+o2rKYDicn9fjbrTxTOj +lI4pHJBH+JU3AJ0tbajupioh70jwFS0oYpwtneg2zcnE2Z4l6mhrj2okrc5Q1/X2 +I2HChtIU4JYTisObtin10QKJX01CLfYXJLa8upWzKZ4/GOcHG+eAV3jXWoXidtjb +1Usw70amoTZ6mIVCkiu1QwCoa8+ycojGfZhvqMsAp1536ZcCul+Na+AbCv4zKS7F +kQQaImVrXdUiFansIoofGlw/JNuoKK6ssVpS5Ic3pgcCAwEAAaM1MDMwDgYDVR0P +AQH/BAQDAgCgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMAwGA1UdEwEB/wQCMAAwCwYJ +KoZIhvcNAQELA4IBAQCKLREH7bXtXtZ+8vI6cjD7W3QikiArGqbl36bAhhWsJLp/ +p/ndKz39iFNaiZ3GlwIURWOOKx3y3GA0x9m8FR+Llthf0EQ8sUjnwaknWs0Y6DQ3 +jjPFZOpV3KPCFrdMJ3++E3MgwFC/Ih/N2ebFX9EcV9Vcc6oVWMdwT0fsrhu683rq +6GSR/3iVX1G/pmOiuaR0fNUaCyCfYrnI4zHBDgSfnlm3vIvN2lrsR/DQBakNL8DJ +HBgKxMGeUPoneBv+c8DMXIL0EhaFXRlBv9QW45/GiAIOuyFJ0i6hCtGZpJjq4OpQ +BRjCI+izPzFTjsxD4aORE+WOkyWFCGPWKfNejfw0 +-----END CERTIFICATE-----` +) + +// writeCAFile writes CA data to a temporary file +func writeCAFile(t testing.TB, caData []byte) string { + tmpDir := t.TempDir() + caFile := filepath.Join(tmpDir, "ca.crt") + + err := os.WriteFile(caFile, caData, 0644) + if err != nil { + t.Fatalf("Failed to write CA file: %v", err) + } + t.Cleanup(func() { + if err := os.Remove(caFile); err != nil { + t.Logf("unexpected error while removing file: %s - %v", caFile, err) + } + }) + return caFile +} + +// createTestTransport creates a test transport with TLS config +func createTestTransport(t testing.TB, caData []byte) *http.Transport { + CAs, err := rootCertPool(caData) + if err != nil { + t.Fatalf("Failed to parse CA certificate") + } + return &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: CAs, + }, + } +} + +func TestCheckCAFileAndRotate(t *testing.T) { + tests := []struct { + name string + setupCA []byte + updateCA []byte + caFileOverride string + expectRotation bool + }{ + { + name: "no change", + setupCA: []byte(testCACert1), + updateCA: []byte(testCACert1), // Same CA + expectRotation: false, + }, + { + name: "CA changed", + setupCA: []byte(testCACert1), + updateCA: []byte(testCACert2), // Different CA + 
expectRotation: true, + }, + { + name: "CA changed to invalid", + setupCA: []byte(testCACert1), + updateCA: []byte("panda"), // invalid CA + expectRotation: false, + }, + { + name: "file error", + setupCA: []byte(testCACert1), + caFileOverride: "/nonexistent/ca.crt", // Non-existent file + expectRotation: false, + }, + { + name: "empty file content", + setupCA: []byte(testCACert1), + updateCA: []byte{}, // Empty file + expectRotation: false, + }, + { + name: "initially empty CA file updated to valid CA", + setupCA: []byte{}, // Starts with empty CA + updateCA: []byte(testCACert1), // Populated with a valid CA + expectRotation: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + caFile := writeCAFile(t, tt.setupCA) + if len(tt.caFileOverride) > 0 { + caFile = tt.caFileOverride + } + + transport := createTestTransport(t, tt.setupCA) + setupRoots := transport.TLSClientConfig.RootCAs.Clone() + + expectedRoots := setupRoots + if tt.expectRotation { + var err error + expectedRoots, err = rootCertPool(tt.updateCA) + if err != nil { + t.Fatal(err) + } + } + + clock := testingclock.NewFakeClock(time.Now()) + holder := newAtomicTransportHolder(caFile, tt.setupCA, transport) + holder.clock = clock + holder.transportLastChecked = clock.Now() + + if tt.updateCA != nil { + // Update the file with new CA content + err := os.WriteFile(caFile, tt.updateCA, 0644) + if err != nil { + t.Errorf("Failed to update CA data with file address: %s", caFile) + } + } + + clock.Step(holder.caRefreshDuration) + + // Check CA file rotation + newTransport := holder.getTransport(t.Context()) + newRoots := newTransport.TLSClientConfig.RootCAs + + if newRoots == nil || !expectedRoots.Equal(newRoots) { + t.Error("new roots did not match expected roots") + } + + transportRotated := newTransport != transport + if tt.expectRotation != transportRotated { + t.Error("transport rotation did not match") + } + + }) + } +} + +func generateServerCertAndCA(t testing.TB) 
(servingCertPEM, servingKeyPEM, caCertPEM []byte) { + t.Helper() + certPEM, keyPEM, err := cert.GenerateSelfSignedCertKey("127.0.0.1", nil, nil) + if err != nil { + t.Fatalf("Failed to generate server cert: %v", err) + } + certs, err := cert.ParseCertsPEM(certPEM) + if err != nil || len(certs) < 2 { + t.Fatal("Expected cert chain with [leaf, CA]") + } + caPEM, err := cert.EncodeCertificates(certs[len(certs)-1]) + if err != nil { + t.Fatalf("Failed to encode CA cert: %v", err) + } + return certPEM, keyPEM, caPEM +} + +// TestCARotationConnectionBehavior tests end-to-end CA rotation: +// 1. Client trusts server CA via a CA file on disk +// 2. Server rotates to a new CA + serving cert +// 3. Client fails (doesn't trust new CA yet) +// 4. CA file updated, transport reloads, client reconnects +func TestCARotationConnectionBehavior(t *testing.T) { + t.Log("Testing CA Rotation Connection Behavior") + + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true) + + // Generate initial server v1 cert and CA + servingCertPEM1, servingKeyPEM1, caCertPEM1 := generateServerCertAndCA(t) + + // Start server v1 (no client cert required - test focuses on server CA rotation) + srv1 := newTestServer(t, servingCertPEM1, servingKeyPEM1) + srv1.StartTLS() + defer srv1.Close() + + // Set up the client + clientCAFile := writeCAFile(t, caCertPEM1) + config := &Config{ + TLS: TLSConfig{ + CAFile: clientCAFile, + }, + } + + transport, err := New(config) + if err != nil { + t.Fatalf("Failed to create transport: %v", err) + } + transport.(*atomicTransportHolder).caRefreshDuration = 500 * time.Millisecond + + client := &http.Client{ + Transport: transport, + Timeout: 30 * time.Second, + } + + // Initial connection must succeed + t.Log("Making initial request to server v1, expecting success...") + resp, err := client.Get(srv1.URL) + if err != nil { + t.Fatalf("Failed to call the server v1: %v", err) + } + if err := resp.Body.Close(); err != nil { + 
t.Fatal("Failed to close the response.") + } + if resp.StatusCode != http.StatusOK { + t.Fatal("Failed to call the server successfully.") + } + + t.Log("Initial connection successful.") + + // Rotate: new CA, new server cert, same address + t.Log("Stopping server v1 and starting server v2 with new CA...") + srv1Addr := srv1.Listener.Addr().String() + srv1.Close() + + servingCertPEM2, servingKeyPEM2, caCertPEM2 := generateServerCertAndCA(t) + srv2 := newTestServer(t, servingCertPEM2, servingKeyPEM2) + l, err := net.Listen("tcp", srv1Addr) + if err != nil { + t.Fatalf("Failed to re-claim the same server address: %v", err) + } + + srv2.Listener = l + srv2.StartTLS() + defer srv2.Close() + + // Must fail - client still trusts old CA + t.Log("Making request to server v2, expecting failure...") + _, err = client.Get(srv2.URL) + if err == nil { + t.Fatal("The request should fail.") + } + t.Log("Request failed as expected.") + + // Update CA file to trust new CA + t.Log("Updating client CA file on disk to trust new CA...") + if err := os.WriteFile(clientCAFile, caCertPEM2, 0644); err != nil { + t.Fatalf("Failed to update CA file: %v", err) + } + + // Poll until transport reloads + t.Log("Polling server v2 until the client's transport reloads the new CA...") + var lastPollErr error + ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer ctxCancel() + + pollErr := wait.PollUntilContextCancel(ctx, 500*time.Millisecond, true, func(ctx context.Context) (bool, error) { + resp, err := client.Get(srv2.URL) + if err != nil { + lastPollErr = err + t.Log("Client failed to connect before the root CAs are updated, will retry...") + return false, nil // Error is expected, continue polling + } + if err := resp.Body.Close(); err != nil { + t.Fatal("Failed to close the response.") + } + if resp.StatusCode == http.StatusOK { + return true, nil // Success! Stop polling. 
+ } + return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + }) + if pollErr != nil { + t.Fatalf("Client failed to reconnect after CA rotation. Last error: %v. Test error: %v", lastPollErr, pollErr) + } + + t.Log("Success! Client reconnected after CA was refreshed.") +} + +func TestCARotationConnectionBehavior_Disabled(t *testing.T) { + t.Log("Testing CA Rotation Connection Behavior (Feature Disabled)") + + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false) + + // Generate initial server v1 cert and CA + servingCertPEM1, servingKeyPEM1, caCertPEM1 := generateServerCertAndCA(t) + + // Start server v1 (no client cert required - test focuses on server CA rotation) + srv1 := newTestServer(t, servingCertPEM1, servingKeyPEM1) + srv1.StartTLS() + defer srv1.Close() + + // Set up the client + clientCAFile := writeCAFile(t, caCertPEM1) + config := &Config{ + TLS: TLSConfig{ + CAFile: clientCAFile, + }, + } + + transport, err := New(config) + if err != nil { + t.Fatalf("Failed to create transport: %v", err) + } + if _, ok := transport.(*atomicTransportHolder); ok { + t.Fatal("Expected plain transport when the feature gate is disabled, got atomicTransportHolder") + } + + client := &http.Client{ + Transport: transport, + Timeout: 30 * time.Second, + } + + // Initial connection must succeed + t.Log("Making initial request to server v1, expecting success...") + resp, err := client.Get(srv1.URL) + if err != nil { + t.Fatalf("Failed to call the server v1: %v", err) + } + if err := resp.Body.Close(); err != nil { + t.Fatal("Failed to close the response.") + } + if resp.StatusCode != http.StatusOK { + t.Fatal("Failed to call the server successfully.") + } + + t.Log("Initial connection successful.") + + // Rotate: new CA, new server cert, same address + t.Log("Stopping server v1 and starting server v2 with new CA...") + srv1Addr := srv1.Listener.Addr().String() + srv1.Close() + + servingCertPEM2, servingKeyPEM2, 
caCertPEM2 := generateServerCertAndCA(t) + srv2 := newTestServer(t, servingCertPEM2, servingKeyPEM2) + l, err := net.Listen("tcp", srv1Addr) + if err != nil { + t.Fatalf("Failed to re-claim the same server address: %v", err) + } + + srv2.Listener = l + srv2.StartTLS() + defer srv2.Close() + + // Must fail - client still trusts old CA + t.Log("Making request to server v2, expecting failure...") + _, err = client.Get(srv2.URL) + if err == nil { + t.Fatal("The request should fail.") + } + t.Log("Request failed as expected.") + + // Update CA file to trust new CA + t.Log("Updating client CA file on disk to trust new CA...") + if err := os.WriteFile(clientCAFile, caCertPEM2, 0644); err != nil { + t.Fatalf("Failed to update CA file: %v", err) + } + + // Poll to ensure transport DOES NOT reload + t.Log("Polling server v2 to verify the client DOES NOT reconnect...") + var lastPollErr error + ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer ctxCancel() + + pollErr := wait.PollUntilContextCancel(ctx, 500*time.Millisecond, true, func(ctx context.Context) (bool, error) { + resp, err := client.Get(srv2.URL) + if err != nil { + lastPollErr = err + t.Log("Client failed to connect (expected because feature is disabled)...") + return false, nil // Keep polling until the timeout is reached + } + if err := resp.Body.Close(); err != nil { + t.Fatal("Failed to close the response.") + } + if resp.StatusCode == http.StatusOK { + return true, nil // Success! But this means the test failed. + } + return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + }) + + // If pollErr is nil, it means the connection succeeded, which is wrong for this test. + if pollErr == nil { + t.Fatalf("Client unexpectedly reconnected after CA rotation! The feature gate is disabled, so this should have failed.") + } + + t.Logf("Success! Client permanently failed to reconnect as expected. 
Last error: %v", lastPollErr) +} + +// testCAReloadsMetric is a fake metric recorder that records calls to Increment. +type testCAReloadsMetric struct { + calls []caReloadCall +} + +type caReloadCall struct { + result, reason string +} + +func (m *testCAReloadsMetric) Increment(result, reason string) { + m.calls = append(m.calls, caReloadCall{result, reason}) +} + +// TestCARotationMetricsEmitted verifies that ca_rotation.go emits the correct +// metrics during actual CA reload operations. +func TestCARotationMetricsEmitted(t *testing.T) { + fakeMetricRecorder := &testCAReloadsMetric{} + origMetric := metrics.TransportCAReloads + metrics.TransportCAReloads = fakeMetricRecorder + t.Cleanup(func() { metrics.TransportCAReloads = origMetric }) + + caData := []byte(testCACert1) + caFile := writeCAFile(t, caData) + transport := createTestTransport(t, caData) + + clock := testingclock.NewFakeClock(time.Now()) + holder := newAtomicTransportHolder(caFile, caData, transport) + holder.clock = clock + holder.transportLastChecked = clock.Now() + + tests := []struct { + name string + setup func() error + wantResult string + wantReason string + }{ + { + name: "unchanged CA", + setup: func() error { return nil }, + wantResult: "success", + wantReason: "unchanged", + }, + { + name: "updated CA", + setup: func() error { + return os.WriteFile(caFile, []byte(testCACert2), 0644) + }, + wantResult: "success", + wantReason: "updated", + }, + { + name: "empty file", + setup: func() error { + return os.WriteFile(caFile, []byte{}, 0644) + }, + wantResult: "failure", + wantReason: "empty", + }, + { + name: "invalid CA data", + setup: func() error { + return os.WriteFile(caFile, []byte("not-a-cert"), 0644) + }, + wantResult: "failure", + wantReason: "ca_parse_error", + }, + { + name: "read error", + setup: func() error { + return os.Remove(caFile) + }, + wantResult: "failure", + wantReason: "read_error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
+			fakeMetricRecorder.calls = nil
+			err := tt.setup()
+			if err != nil {
+				t.Fatalf("Setup failed: %v", err)
+			}
+			clock.Step(holder.caRefreshDuration)
+			holder.getTransport(context.Background())
+
+			if len(fakeMetricRecorder.calls) != 1 {
+				t.Fatalf("Expected 1 metric call, got %d", len(fakeMetricRecorder.calls))
+			}
+			if fakeMetricRecorder.calls[0].result != tt.wantResult || fakeMetricRecorder.calls[0].reason != tt.wantReason {
+				t.Errorf("Got metric(%s, %s), want (%s, %s)",
+					fakeMetricRecorder.calls[0].result, fakeMetricRecorder.calls[0].reason, tt.wantResult, tt.wantReason)
+			}
+		})
+	}
+}
+
+// newTestServer returns an unstarted httptest server that answers every request with 200 "ok" and serves TLS with certPEM/keyPEM.
+func newTestServer(t *testing.T, certPEM, keyPEM []byte) *httptest.Server {
+	t.Helper()
+	server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		if _, err := fmt.Fprint(w, "ok"); err != nil {
+			t.Errorf("Failed to write to the response: %v", err) // not t.Fatal: the handler runs on a server goroutine, not the test goroutine
+		}
+	}))
+	// Configure server TLS
+	tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
+	if err != nil {
+		t.Fatalf("Failed to create server cert: %v", err)
+	}
+
+	server.TLS = &tls.Config{
+		Certificates: []tls.Certificate{tlsCert},
+	}
+	return server
+}
diff --git a/transport/cache.go b/transport/cache.go
index b8dd86611..ae1c39ffe 100644
--- a/transport/cache.go
+++ b/transport/cache.go
@@ -36,7 +36,7 @@ import (
 // the config has no custom TLS options, http.DefaultTransport is returned.
 type tlsTransportCache struct {
 	mu         sync.Mutex
-	transports map[tlsCacheKey]*http.Transport
+	transports map[tlsCacheKey]http.RoundTripper
 }
 
 // DialerStopCh is stop channel that is passed down to dynamic cert dialer.
@@ -46,11 +46,14 @@ var DialerStopCh = wait.NeverStop const idleConnsPerHost = 25 -var tlsCache = &tlsTransportCache{transports: make(map[tlsCacheKey]*http.Transport)} +var tlsCache = &tlsTransportCache{ + transports: make(map[tlsCacheKey]http.RoundTripper), +} type tlsCacheKey struct { insecure bool caData string + caFile string certData string keyData string `datapolicy:"security-key"` certFile string @@ -68,8 +71,8 @@ func (t tlsCacheKey) String() string { if len(t.keyData) > 0 { keyText = "" } - return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p", - t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial) + return fmt.Sprintf("insecure:%v, caData:%#v, caFile:%s, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p", + t.insecure, t.caData, t.caFile, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial) } func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { @@ -131,7 +134,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { proxy = config.Proxy } - transport := utilnet.SetTransportDefaults(&http.Transport{ + httpTransport := utilnet.SetTransportDefaults(&http.Transport{ Proxy: proxy, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, @@ -139,6 +142,11 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { DialContext: dial, DisableCompression: config.DisableCompression, }) + var transport http.RoundTripper = httpTransport + + if config.TLS.ReloadCAFiles && tlsConfig != nil && tlsConfig.RootCAs != nil && len(config.TLS.CAFile) > 0 { + transport = newAtomicTransportHolder(config.TLS.CAFile, config.TLS.CAData, httpTransport) + } if canCache { // Cache a single transport for these options @@ -162,7 +170,6 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) { k := tlsCacheKey{ insecure: 
c.TLS.Insecure, - caData: string(c.TLS.CAData), serverName: c.TLS.ServerName, nextProtos: strings.Join(c.TLS.NextProtos, ","), disableCompression: c.DisableCompression, @@ -178,5 +185,14 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) { k.keyData = string(c.TLS.KeyData) } + if c.TLS.ReloadCAFiles { + // When reloading CA files, include CA file path in cache key instead of CA data + // This allows the CA to be reloaded from disk on each transport creation + k.caFile = c.TLS.CAFile + } else { + // When not reloading, cache the CA data directly + k.caData = string(c.TLS.CAData) + } + return k, true, nil } diff --git a/transport/cache_test.go b/transport/cache_test.go index 54705276d..e7cb74019 100644 --- a/transport/cache_test.go +++ b/transport/cache_test.go @@ -19,14 +19,22 @@ package transport import ( "context" "crypto/tls" + "crypto/x509" "net" "net/http" "net/url" + "os" "reflect" "testing" + "time" + + clientgofeaturegate "k8s.io/client-go/features" + clientfeaturestesting "k8s.io/client-go/features/testing" ) func TestTLSConfigKey(t *testing.T) { + + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true) // Make sure config fields that don't affect the tls config don't affect the cache key identicalConfigurations := map[string]*Config{ "empty": {}, @@ -70,14 +78,17 @@ func TestTLSConfigKey(t *testing.T) { // Make sure config fields that affect the tls config affect the cache key dialer := net.Dialer{} getCert := &GetCertHolder{GetCert: func() (*tls.Certificate, error) { return nil, nil }} + caFile := writeCAFile(t, []byte(testCACert1)) uniqueConfigurations := map[string]*Config{ - "proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }}, - "no tls": {}, - "dialer": {DialHolder: &DialHolder{Dial: dialer.DialContext}}, - "dialer2": {DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}}, - "insecure": {TLS: 
TLSConfig{Insecure: true}}, - "cadata 1": {TLS: TLSConfig{CAData: []byte{1}}}, - "cadata 2": {TLS: TLSConfig{CAData: []byte{2}}}, + "proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }}, + "no tls": {}, + "dialer": {DialHolder: &DialHolder{Dial: dialer.DialContext}}, + "dialer2": {DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}}, + "insecure": {TLS: TLSConfig{Insecure: true}}, + "cadata 1": {TLS: TLSConfig{CAData: []byte{1}}}, + "cadata 2": {TLS: TLSConfig{CAData: []byte{2}}}, + "with only ca file": {TLS: TLSConfig{CAFile: caFile}}, + "with both ca file and ca data": {TLS: TLSConfig{CAFile: caFile, CAData: []byte(testCACert1)}}, "cert 1, key 1": { TLS: TLSConfig{ CertData: []byte{1}, @@ -188,3 +199,228 @@ func TestTLSConfigKey(t *testing.T) { } } } + +func TestTLSConfigKeyCARotationDisabled(t *testing.T) { + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false) + + caFile := writeCAFile(t, []byte(testCACert1)) + + // When feature is disabled, CAFile-only config resolves CAData via + // loadTLSFiles, so two configs with the same file content get the same key. 
+ config1 := &Config{TLS: TLSConfig{CAFile: caFile}} + config2 := &Config{TLS: TLSConfig{CAData: []byte(testCACert1)}} + + if err := loadTLSFiles(config1); err != nil { + t.Fatal(err) + } + if err := loadTLSFiles(config2); err != nil { + t.Fatal(err) + } + + key1, canCache1, err := tlsConfigKey(config1) + if err != nil || !canCache1 { + t.Fatalf("unexpected: err=%v, canCache=%v", err, canCache1) + } + + key2, canCache2, err := tlsConfigKey(config2) + if err != nil || !canCache2 { + t.Fatalf("unexpected: err=%v, canCache=%v", err, canCache2) + } + + if key1 != key2 { + t.Error("Expected same cache key when feature is disabled (CAFile resolved to CAData)") + } + if config1.TLS.ReloadCAFiles { + t.Error("Expected ReloadCAFiles=false when feature gate is disabled") + } +} + +// TestTLSTransportCacheCARotation tests transport cache behavior with CA rotation +func TestTLSTransportCacheCARotation(t *testing.T) { + + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true) + caFile := writeCAFile(t, []byte(testCACert1)) + testCases := []struct { + name string + config *Config + expectWrapper bool + expectCacheable bool + }{ + { + name: "CA rotation should be enabled when only the CAFile is set", + config: &Config{ + TLS: TLSConfig{ + CAFile: caFile, + }, + }, + expectWrapper: true, + expectCacheable: true, + }, + { + name: "CA rotation should be disabled when both CAFile and CAData are set", + config: &Config{ + TLS: TLSConfig{ + CAFile: caFile, + CAData: []byte(testCACert1), + }, + }, + expectWrapper: false, + expectCacheable: true, + }, + { + name: "CA rotation should be disabled when only the CAData is set", + config: &Config{ + TLS: TLSConfig{ + CAData: []byte(testCACert1), + }, + }, + expectWrapper: false, + expectCacheable: true, + }, + { + name: "no TLS config", + config: &Config{ + TLS: TLSConfig{}, + }, + expectWrapper: false, + expectCacheable: false, // No TLS config means default transport + }, + } + + for _, tc := range 
testCases { + t.Run(tc.name, func(t *testing.T) { + // Create new cache for testing + tlsCaches := &tlsTransportCache{ + transports: make(map[tlsCacheKey]http.RoundTripper), + } + + rt, err := tlsCaches.get(tc.config) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !tc.expectCacheable { + // Should return default transport + if rt != http.DefaultTransport { + t.Errorf("Expected default transport, got %T", rt) + } + return + } + + if tc.expectWrapper { + // Should be wrapped in atomicTransportHolder + if _, ok := rt.(*atomicTransportHolder); !ok { + t.Errorf("Expected atomicTransportHolder, got %T", rt) + } + if !tc.config.TLS.ReloadCAFiles { + t.Errorf("Expected ReloadCAFiles to be true, got %v", tc.config.TLS.ReloadCAFiles) + } + } else { + // Should be a regular http.Transport + if _, ok := rt.(*http.Transport); !ok { + t.Errorf("Expected *http.Transport, got %T", rt) + } + } + + // Test caching: second call should return the same instance + rt2, err := tlsCaches.get(tc.config) + if err != nil { + t.Fatalf("Unexpected error on second call: %v", err) + } + + if rt != rt2 { + t.Error("Expected same transport instance from cache") + } + + // Verify cache size + tlsCaches.mu.Lock() + cacheSize := len(tlsCaches.transports) + tlsCaches.mu.Unlock() + + expectedCacheSize := 1 + if !tc.expectCacheable { + expectedCacheSize = 0 + } + + if cacheSize != expectedCacheSize { + t.Errorf("Expected %d transports in cache, got %d", expectedCacheSize, cacheSize) + } + }) + } +} + +func TestTLSTransportCacheCARotationDisabled(t *testing.T) { + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false) + + caFile := writeCAFile(t, []byte(testCACert1)) + cache := &tlsTransportCache{transports: make(map[tlsCacheKey]http.RoundTripper)} + + rt, err := cache.get(&Config{TLS: TLSConfig{CAFile: caFile}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if _, ok := rt.(*atomicTransportHolder); ok { + 
t.Error("Expected plain *http.Transport when feature gate is disabled, got atomicTransportHolder") + } + if _, ok := rt.(*http.Transport); !ok { + t.Errorf("Expected *http.Transport, got %T", rt) + } +} + +func TestEmptyCAFileRotationLifecycle(t *testing.T) { + // Enable the feature gate for the duration of the test + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true) + + // Create a valid file path, but with empty cert data + emptyFile := writeCAFile(t, []byte{}) + + config := &Config{ + TLS: TLSConfig{ + CAFile: emptyFile, + }, + } + + tlsCaches := &tlsTransportCache{ + transports: make(map[tlsCacheKey]http.RoundTripper), + } + + rt, err := tlsCaches.get(config) + if err != nil { + t.Fatalf("Unexpected error getting transport: %v", err) + } + + // Verify newAtomicTransportHolder is successfully generated + holder, ok := rt.(*atomicTransportHolder) + if !ok { + t.Fatalf("Expected atomicTransportHolder, got %T", rt) + } + + // Verify the initial state: RootCAs should be non-nil but empty + initialTransport := holder.getTransport(context.Background()) + + if initialTransport.TLSClientConfig == nil || initialTransport.TLSClientConfig.RootCAs == nil { + t.Fatal("Expected RootCAs to be non-nil for an empty CA file (should be an empty CertPool)") + } + emptyPool := x509.NewCertPool() + if !initialTransport.TLSClientConfig.RootCAs.Equal(emptyPool) { + t.Fatal("Expected initially empty RootCAs") + } + + // Write valid cert data into the CA file + if err := os.WriteFile(emptyFile, []byte(testCACert1), 0644); err != nil { + t.Fatalf("Failed to write to CA file: %v", err) + } + + holder.mu.Lock() + // Set last checked time far in the past to force a refresh on next getTransport call + holder.transportLastChecked = time.Now().Add(-time.Hour) + holder.mu.Unlock() + + refreshedTransport := holder.getTransport(context.Background()) + + // Verify the refresh succeeded and the cert pool is now populated + if 
refreshedTransport.TLSClientConfig.RootCAs.Equal(emptyPool) { + t.Fatal("Expected RootCAs to be populated after writing valid cert data and refreshing") + } +} diff --git a/transport/config.go b/transport/config.go index d8a3d64b3..f3ddcca33 100644 --- a/transport/config.go +++ b/transport/config.go @@ -134,7 +134,8 @@ type TLSConfig struct { CAFile string // Path of the PEM-encoded server trusted root certificates. CertFile string // Path of the PEM-encoded client certificate. KeyFile string // Path of the PEM-encoded client key. - ReloadTLSFiles bool // Set to indicate that the original config provided files, and that they should be reloaded + ReloadTLSFiles bool // Set to indicate that the original config provided files, and that they should be reloaded. + ReloadCAFiles bool // Set to indicate that CA files should be reloaded from disk. Insecure bool // Server should be accessed without verifying the certificate. For testing only. ServerName string // Override for the server name passed to the server for SNI and used to verify certificates. diff --git a/transport/transport.go b/transport/transport.go index 8fdcc5700..be97b7621 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -28,6 +28,7 @@ import ( "time" utilnet "k8s.io/apimachinery/pkg/util/net" + clientgofeaturegate "k8s.io/client-go/features" "k8s.io/klog/v2" ) @@ -211,17 +212,26 @@ func TLSConfigFor(c *Config) (*tls.Config, error) { // KeyData, and CAFile fields, or returns an error. If no error is returned, all three fields are // either populated or were empty to start. 
func loadTLSFiles(c *Config) error { + // Check that we are purely loading CA from file + if clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowCARotation) { + if len(c.TLS.CAFile) > 0 && len(c.TLS.CAData) == 0 { + c.TLS.ReloadCAFiles = true + } + } else if c.TLS.ReloadCAFiles { + return fmt.Errorf("ReloadCAFiles=true requires ClientsAllowCARotation to be enabled") + } + + // Check that we are purely loading certs and keys from files + if len(c.TLS.CertFile) > 0 && len(c.TLS.CertData) == 0 && len(c.TLS.KeyFile) > 0 && len(c.TLS.KeyData) == 0 { + c.TLS.ReloadTLSFiles = true + } + var err error c.TLS.CAData, err = dataFromSliceOrFile(c.TLS.CAData, c.TLS.CAFile) if err != nil { return err } - // Check that we are purely loading from files - if len(c.TLS.CertFile) > 0 && len(c.TLS.CertData) == 0 && len(c.TLS.KeyFile) > 0 && len(c.TLS.KeyData) == 0 { - c.TLS.ReloadTLSFiles = true - } - c.TLS.CertData, err = dataFromSliceOrFile(c.TLS.CertData, c.TLS.CertFile) if err != nil { return err @@ -254,6 +264,11 @@ func rootCertPool(caData []byte) (*x509.CertPool, error) { // code for a look at the platform specific insanity), so we'll use the fact that RootCAs == nil gives us the system values // It doesn't allow trusting either/or, but hopefully that won't be an issue if len(caData) == 0 { + // When the ClientsAllowCARotation feature gate is enabled, it returns an empty but non-nil pool. + // This ensures we don't fall back to system roots when a user explicitly points CAFile to a zero-byte file. 
+ if clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowCARotation) { + return x509.NewCertPool(), nil + } return nil, nil } diff --git a/transport/transport_test.go b/transport/transport_test.go index 188044220..d732ebdef 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -19,11 +19,15 @@ package transport import ( "context" "crypto/tls" + "crypto/x509" "errors" "fmt" "net" "net/http" "testing" + + clientgofeaturegate "k8s.io/client-go/features" + clientfeaturestesting "k8s.io/client-go/features/testing" ) const ( @@ -356,6 +360,16 @@ func TestNew(t *testing.T) { } for k, testCase := range testCases { t.Run(k, func(t *testing.T) { + // The Close method of httptest Server mutates the + // `http.DefaultTransport` object, the 'TLSClientConfig' + // field mutates from nil to a non nil instance. This introduces flake + // and data race when running tests under transport package in parallel. + // To work around it we reset the TLSClientConfig field. 
+ // + // See: https://github.com/golang/go/issues/65796 + if testCase.Default { + http.DefaultTransport.(*http.Transport).TLSClientConfig = nil + } rt, err := New(testCase.Config) switch { case testCase.Err && err == nil: @@ -556,3 +570,51 @@ func Test_contextCanceller_RoundTrip(t *testing.T) { }) } } + +func TestRootCertPoolEmptyData(t *testing.T) { + testCases := []struct { + name string + featureGateEnabled bool + expectNilPool bool + }{ + { + name: "feature gate disabled returns nil (system roots)", + featureGateEnabled: false, + expectNilPool: true, + }, + { + name: "feature gate enabled returns empty pool (trust nothing)", + featureGateEnabled: true, + expectNilPool: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Set the feature gate according to the test case + clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, tc.featureGateEnabled) + + // Call the function with empty caData + pool, err := rootCertPool([]byte{}) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if tc.expectNilPool { + if pool != nil { + t.Fatalf("Expected pool to be nil when feature gate is disabled, but got a populated pool") + } + } else { + if pool == nil { + t.Fatalf("Expected pool to be non-nil (empty pool) when feature gate is enabled, but got nil") + } + + // Verify it is truly an empty pool + emptyPool := x509.NewCertPool() + if !pool.Equal(emptyPool) { + t.Fatalf("Expected the returned pool to be completely empty") + } + } + }) + } +}