client-go CA rotation

Kubernetes-commit: b806a3207beb1afabc3d3d93fbadba38ffb8a110
This commit is contained in:
tinatingyu
2026-03-02 20:40:00 +00:00
committed by Kubernetes Publisher
parent 6f2c112d22
commit 8776b282cc
9 changed files with 1090 additions and 19 deletions

View File

@@ -38,6 +38,13 @@ const (
// events for items popped off the FIFO.
AtomicFIFO Feature = "AtomicFIFO"
// owner: @yt2985
// beta: 1.36
//
// If enabled, allows clients to gracefully handle Certificate Authority (CA)
// rotations without dropping connections or requiring a restart.
ClientsAllowCARotation Feature = "ClientsAllowCARotation"
// owner: @benluddy
// kep: https://kep.k8s.io/4222
// alpha: 1.32
@@ -98,6 +105,9 @@ var defaultVersionedKubernetesFeatureGates = map[Feature]VersionedSpecs{
AtomicFIFO: {
{Version: version.MustParse("1.36"), Default: true, PreRelease: Beta},
},
ClientsAllowCARotation: {
{Version: version.MustParse("1.36"), Default: true, PreRelease: Beta},
},
ClientsAllowCBOR: {
{Version: version.MustParse("1.32"), Default: false, PreRelease: Alpha},
},

View File

@@ -85,6 +85,12 @@ type TransportCreateCallsMetric interface {
Increment(result string)
}
// TransportCAReloadsMetric counts the number of times a CA reload is attempted,
// partitioned by the result and reason.
type TransportCAReloadsMetric interface {
// Increment records a single CA reload attempt under the given result and
// reason labels.
Increment(result, reason string)
}
var (
// ClientCertExpiry is the expiry time of a client certificate
ClientCertExpiry ExpiryMetric = noopExpiry{}
@@ -117,6 +123,9 @@ var (
// TransportCreateCalls is the metric that counts the number of times a new transport
// is created
TransportCreateCalls TransportCreateCallsMetric = noopTransportCreateCalls{}
// TransportCAReloads is the metric that counts the number of times a CA reload is attempted
TransportCAReloads TransportCAReloadsMetric = noopTransportCAReloads{}
)
// RegisterOpts contains all the metrics to register. Metrics may be nil.
@@ -134,6 +143,7 @@ type RegisterOpts struct {
RequestRetry RetryMetric
TransportCacheEntries TransportCacheMetric
TransportCreateCalls TransportCreateCallsMetric
TransportCAReloads TransportCAReloadsMetric
}
// Register registers metrics for the rest client to use. This can
@@ -179,6 +189,9 @@ func Register(opts RegisterOpts) {
if opts.TransportCreateCalls != nil {
TransportCreateCalls = opts.TransportCreateCalls
}
if opts.TransportCAReloads != nil {
TransportCAReloads = opts.TransportCAReloads
}
})
}
@@ -226,3 +239,7 @@ func (noopTransportCache) Observe(int) {}
type noopTransportCreateCalls struct{}
func (noopTransportCreateCalls) Increment(string) {}
// noopTransportCAReloads is a TransportCAReloadsMetric that discards every
// observation.
type noopTransportCAReloads struct{}
// Increment implements TransportCAReloadsMetric and intentionally does nothing.
func (noopTransportCAReloads) Increment(result, reason string) {}

154
transport/ca_rotation.go Normal file
View File

@@ -0,0 +1,154 @@
/*
Copyright The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package transport
import (
"bytes"
"context"
"net/http"
"os"
"sync"
"time"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/client-go/tools/metrics"
"k8s.io/klog/v2"
"k8s.io/utils/clock"
)
var _ utilnet.RoundTripperWrapper = &atomicTransportHolder{}
// atomicTransportHolder holds a transport that can be atomically updated
// when CA files change, enabling graceful CA rotation without cache complexity.
//
// The CA file is re-read lazily from the request path, at most once per
// caRefreshDuration; when its content changed and parses, the held transport
// is replaced by a clone that trusts the new roots.
type atomicTransportHolder struct {
// caFile is the path of the CA bundle that is re-read on refresh.
caFile string
// currentCAData is only read and written while mu is held for writing
// (see tryRefreshTransportLocked).
currentCAData []byte // Track the actual CA data currently in use
// clock and caRefreshDuration are used to allow for testing time-based logic.
clock clock.Clock
// caRefreshDuration is the minimum interval between two CA file checks.
caRefreshDuration time.Duration
// mu covers transport and transportLastChecked
// (and, in practice, currentCAData as noted above).
mu sync.RWMutex
// transport is the transport currently used to serve requests.
transport *http.Transport
// transportLastChecked is the time of the most recent CA file check attempt,
// whether or not that attempt succeeded.
transportLastChecked time.Time
}
// RoundTrip implements http.RoundTripper. It sends req through the transport
// returned by getTransport, which may first re-check the CA file when the
// refresh interval has elapsed.
func (h *atomicTransportHolder) RoundTrip(req *http.Request) (*http.Response, error) {
	current := h.getTransport(req.Context())
	return current.RoundTrip(req)
}
// WrappedRoundTripper returns the transport currently held. It takes only the
// read lock and never triggers a CA file check.
func (h *atomicTransportHolder) WrappedRoundTripper() http.RoundTripper {
	h.mu.RLock()
	current := h.transport
	h.mu.RUnlock()
	return current
}
// getTransport returns the transport to use for a request. When the last CA
// check is recent enough the held transport is returned under the read lock;
// otherwise the write lock is taken, a refresh is attempted, and whatever
// transport is held afterwards is returned.
func (h *atomicTransportHolder) getTransport(ctx context.Context) *http.Transport {
	fresh := h.getTransportIfFresh()
	if fresh != nil {
		return fresh
	}

	h.mu.Lock()
	defer h.mu.Unlock()

	h.tryRefreshTransportLocked(ctx)
	return h.transport
}
// getTransportIfFresh returns the held transport if the CA file was checked
// less than caRefreshDuration ago, and nil when a new check is due.
func (h *atomicTransportHolder) getTransportIfFresh() *http.Transport {
	h.mu.RLock()
	defer h.mu.RUnlock()

	if age := h.clock.Since(h.transportLastChecked); age >= h.caRefreshDuration {
		return nil
	}
	return h.transport
}
// tryRefreshTransportLocked re-reads the CA file and, when its content differs
// from the CA data currently in use and parses into a pool, swaps the held
// transport for a clone that uses the new roots. The caller must hold h.mu for
// writing. On any failure the existing transport is kept, and every outcome is
// reported through metrics.TransportCAReloads.
func (h *atomicTransportHolder) tryRefreshTransportLocked(ctx context.Context) {
	// Another goroutine may have completed the check while this one was
	// waiting for the write lock.
	if sinceCheck := h.clock.Since(h.transportLastChecked); sinceCheck < h.caRefreshDuration {
		return
	}

	// Stamp the attempt before doing any work so that failed reloads are also
	// limited to one per caRefreshDuration.
	h.transportLastChecked = h.clock.Now()

	logger := klog.FromContext(ctx).WithValues("caFile", h.caFile)
	logger.V(4).Info("Checking CA file content")

	fileData, readErr := os.ReadFile(h.caFile)
	switch {
	case readErr != nil:
		// Keep the current transport when the file cannot be read.
		logger.Error(readErr, "Failed to read CA data from file")
		metrics.TransportCAReloads.Increment("failure", "read_error")
		return
	case len(fileData) == 0:
		logger.Info("CA file empty, skipping transport rotation")
		metrics.TransportCAReloads.Increment("failure", "empty")
		return
	case bytes.Equal(h.currentCAData, fileData):
		logger.V(4).Info("CA file unchanged, skipping transport rotation")
		metrics.TransportCAReloads.Increment("success", "unchanged")
		return
	}

	logger.V(4).Info("CA content changed, updating transport")

	pool, parseErr := rootCertPool(fileData)
	if parseErr != nil {
		// Keep the current transport when the new content does not parse.
		logger.Error(parseErr, "Failed to parse CA data from file")
		metrics.TransportCAReloads.Increment("failure", "ca_parse_error")
		return
	}

	previous := h.transport
	replacement := previous.Clone()
	replacement.TLSClientConfig.RootCAs = pool
	h.transport = replacement

	// Remember which CA bytes the new transport was built from.
	h.currentCAData = fileData

	// Dropping idle connections on the old transport nudges traffic onto the
	// replacement.
	previous.CloseIdleConnections()

	logger.V(4).Info("Transport updated for CA rotation")
	metrics.TransportCAReloads.Increment("success", "updated")
}
// newAtomicTransportHolder builds a holder that reloads CAs from caFile.
//
// caFile must be set. caData may be empty, but should be the content of caFile
// that transport was built from. transport must carry a TLS config whose root
// CAs correspond to caData. The holder uses the real clock, a five minute
// refresh interval, and treats the CA as freshly checked at creation time.
func newAtomicTransportHolder(caFile string, caData []byte, transport *http.Transport) *atomicTransportHolder {
	realClock := clock.RealClock{}
	holder := &atomicTransportHolder{
		caFile:            caFile,
		currentCAData:     caData,
		clock:             realClock,
		caRefreshDuration: 5 * time.Minute,
		transport:         transport,
	}
	holder.transportLastChecked = realClock.Now()
	return holder
}

View File

@@ -0,0 +1,560 @@
/*
Copyright The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package transport
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"k8s.io/apimachinery/pkg/util/wait"
clientgofeaturegate "k8s.io/client-go/features"
clientfeaturestesting "k8s.io/client-go/features/testing"
"k8s.io/client-go/tools/metrics"
"k8s.io/client-go/util/cert"
testingclock "k8s.io/utils/clock/testing"
)
const (
// Use the same rootCACert as transport_test.go
testCACert1 = `-----BEGIN CERTIFICATE-----
MIIC4DCCAcqgAwIBAgIBATALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIxNTczN1oXDTE2MDExNTIxNTcz
OFowIzEhMB8GA1UEAwwYMTAuMTMuMTI5LjEwNkAxNDIxMzU5MDU4MIIBIjANBgkq
hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAunDRXGwsiYWGFDlWH6kjGun+PshDGeZX
xtx9lUnL8pIRWH3wX6f13PO9sktaOWW0T0mlo6k2bMlSLlSZgG9H6og0W6gLS3vq
s4VavZ6DbXIwemZG2vbRwsvR+t4G6Nbwelm6F8RFnA1Fwt428pavmNQ/wgYzo+T1
1eS+HiN4ACnSoDSx3QRWcgBkB1g6VReofVjx63i0J+w8Q/41L9GUuLqquFxu6ZnH
60vTB55lHgFiDLjA1FkEz2dGvGh/wtnFlRvjaPC54JH2K1mPYAUXTreoeJtLJKX0
ycoiyB24+zGCniUmgIsmQWRPaOPircexCp1BOeze82BT1LCZNTVaxQIDAQABoyMw
ITAOBgNVHQ8BAf8EBAMCAKQwDwYDVR0TAQH/BAUwAwEB/zALBgkqhkiG9w0BAQsD
ggEBADMxsUuAFlsYDpF4fRCzXXwrhbtj4oQwcHpbu+rnOPHCZupiafzZpDu+rw4x
YGPnCb594bRTQn4pAu3Ac18NbLD5pV3uioAkv8oPkgr8aUhXqiv7KdDiaWm6sbAL
EHiXVBBAFvQws10HMqMoKtO8f1XDNAUkWduakR/U6yMgvOPwS7xl0eUTqyRB6zGb
K55q2dejiFWaFqB/y78txzvz6UlOZKE44g2JAVoJVM6kGaxh33q8/FmrL4kuN3ut
W+MmJCVDvd4eEqPwbp7146ZWTqpIJ8lvA6wuChtqV8lhAPka2hD/LMqY8iXNmfXD
uml0obOEy+ON91k+SWTJ3ggmF/U=
-----END CERTIFICATE-----`
// A different CA cert for testing rotation (modified version of certData from transport_test.go)
testCACert2 = `-----BEGIN CERTIFICATE-----
MIIC6jCCAdSgAwIBAgIBCzALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIyMDEzMVoXDTE2MDExNTIyMDEz
MlowGzEZMBcGA1UEAxMQb3BlbnNoaWZ0LWNsaWVudDCCASIwDQYJKoZIhvcNAQEB
BQADggEPADCCAQoCggEBAKtdhz0+uCLXw5cSYns9rU/XifFSpb/x24WDdrm72S/v
b9BPYsAStiP148buylr1SOuNi8sTAZmlVDDIpIVwMLff+o2rKYDicn9fjbrTxTOj
lI4pHJBH+JU3AJ0tbajupioh70jwFS0oYpwtneg2zcnE2Z4l6mhrj2okrc5Q1/X2
I2HChtIU4JYTisObtin10QKJX01CLfYXJLa8upWzKZ4/GOcHG+eAV3jXWoXidtjb
1Usw70amoTZ6mIVCkiu1QwCoa8+ycojGfZhvqMsAp1536ZcCul+Na+AbCv4zKS7F
kQQaImVrXdUiFansIoofGlw/JNuoKK6ssVpS5Ic3pgcCAwEAAaM1MDMwDgYDVR0P
AQH/BAQDAgCgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMAwGA1UdEwEB/wQCMAAwCwYJ
KoZIhvcNAQELA4IBAQCKLREH7bXtXtZ+8vI6cjD7W3QikiArGqbl36bAhhWsJLp/
p/ndKz39iFNaiZ3GlwIURWOOKx3y3GA0x9m8FR+Llthf0EQ8sUjnwaknWs0Y6DQ3
jjPFZOpV3KPCFrdMJ3++E3MgwFC/Ih/N2ebFX9EcV9Vcc6oVWMdwT0fsrhu683rq
6GSR/3iVX1G/pmOiuaR0fNUaCyCfYrnI4zHBDgSfnlm3vIvN2lrsR/DQBakNL8DJ
HBgKxMGeUPoneBv+c8DMXIL0EhaFXRlBv9QW45/GiAIOuyFJ0i6hCtGZpJjq4OpQ
BRjCI+izPzFTjsxD4aORE+WOkyWFCGPWKfNejfw0
-----END CERTIFICATE-----`
)
// writeCAFile writes caData to a file named ca.crt inside a fresh per-test
// temporary directory and returns the path of that file. The test is failed
// immediately if the file cannot be written.
//
// The file is removed again when the test finishes. Tests are free to delete
// the file themselves (for example to provoke a read error); an already
// missing file is not reported during cleanup.
func writeCAFile(t testing.TB, caData []byte) string {
	t.Helper()

	caFile := filepath.Join(t.TempDir(), "ca.crt")
	if err := os.WriteFile(caFile, caData, 0644); err != nil {
		t.Fatalf("Failed to write CA file: %v", err)
	}

	t.Cleanup(func() {
		// Only surface removal problems other than "the test already deleted it".
		if err := os.Remove(caFile); err != nil && !os.IsNotExist(err) {
			t.Logf("unexpected error while removing file: %s - %v", caFile, err)
		}
	})
	return caFile
}
// createTestTransport returns an *http.Transport whose TLS client config uses
// the root pool parsed from caData by rootCertPool. The test is failed
// immediately, with the parse error included, if caData cannot be parsed.
func createTestTransport(t testing.TB, caData []byte) *http.Transport {
	t.Helper()

	roots, err := rootCertPool(caData)
	if err != nil {
		t.Fatalf("Failed to parse CA certificate: %v", err)
	}
	return &http.Transport{
		TLSClientConfig: &tls.Config{
			RootCAs: roots,
		},
	}
}
// TestCheckCAFileAndRotate exercises a single refresh of atomicTransportHolder
// against a fake clock. Each case creates a holder from setupCA, optionally
// rewrites the CA file with updateCA (or points the holder at a different
// path), advances the clock by the refresh interval, and then checks both the
// resulting root pool and whether the transport instance was replaced.
func TestCheckCAFileAndRotate(t *testing.T) {
	tests := []struct {
		name           string
		setupCA        []byte
		updateCA       []byte
		caFileOverride string
		expectRotation bool
	}{
		{
			name:           "no change",
			setupCA:        []byte(testCACert1),
			updateCA:       []byte(testCACert1), // Same CA
			expectRotation: false,
		},
		{
			name:           "CA changed",
			setupCA:        []byte(testCACert1),
			updateCA:       []byte(testCACert2), // Different CA
			expectRotation: true,
		},
		{
			name:           "CA changed to invalid",
			setupCA:        []byte(testCACert1),
			updateCA:       []byte("panda"), // invalid CA
			expectRotation: false,
		},
		{
			name:           "file error",
			setupCA:        []byte(testCACert1),
			caFileOverride: "/nonexistent/ca.crt", // Non-existent file
			expectRotation: false,
		},
		{
			name:           "empty file content",
			setupCA:        []byte(testCACert1),
			updateCA:       []byte{}, // Empty file
			expectRotation: false,
		},
		{
			name:           "initially empty CA file updated to valid CA",
			setupCA:        []byte{},            // Starts with empty CA
			updateCA:       []byte(testCACert1), // Populated with a valid CA
			expectRotation: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			caFile := writeCAFile(t, tt.setupCA)
			if len(tt.caFileOverride) > 0 {
				caFile = tt.caFileOverride
			}

			transport := createTestTransport(t, tt.setupCA)
			setupRoots := transport.TLSClientConfig.RootCAs.Clone()
			expectedRoots := setupRoots
			if tt.expectRotation {
				var err error
				expectedRoots, err = rootCertPool(tt.updateCA)
				if err != nil {
					t.Fatal(err)
				}
			}

			fakeClock := testingclock.NewFakeClock(time.Now())
			holder := newAtomicTransportHolder(caFile, tt.setupCA, transport)
			holder.clock = fakeClock
			holder.transportLastChecked = fakeClock.Now()

			// A nil updateCA means "leave the file alone"; an empty non-nil
			// slice deliberately truncates it.
			if tt.updateCA != nil {
				// Without the updated content the remaining assertions would
				// be meaningless, so stop the case right away on failure.
				if err := os.WriteFile(caFile, tt.updateCA, 0644); err != nil {
					t.Fatalf("Failed to update CA data with file address: %s: %v", caFile, err)
				}
			}

			// Make the next getTransport call due for a CA file check.
			fakeClock.Step(holder.caRefreshDuration)

			newTransport := holder.getTransport(t.Context())
			newRoots := newTransport.TLSClientConfig.RootCAs
			if newRoots == nil || !expectedRoots.Equal(newRoots) {
				t.Error("new roots did not match expected roots")
			}

			transportRotated := newTransport != transport
			if tt.expectRotation != transportRotated {
				t.Error("transport rotation did not match")
			}
		})
	}
}
// generateServerCertAndCA creates a self-signed serving certificate for
// 127.0.0.1 and returns the PEM-encoded serving chain, the PEM-encoded key and
// the PEM encoding of the last certificate in the chain, which is used as the
// CA. The test is failed immediately if generation or parsing fails, or if the
// chain holds fewer than two certificates.
func generateServerCertAndCA(t testing.TB) (servingCertPEM, servingKeyPEM, caCertPEM []byte) {
	t.Helper()

	certPEM, keyPEM, err := cert.GenerateSelfSignedCertKey("127.0.0.1", nil, nil)
	if err != nil {
		t.Fatalf("Failed to generate server cert: %v", err)
	}

	certs, err := cert.ParseCertsPEM(certPEM)
	if err != nil {
		// Report the parse failure itself rather than a misleading chain-length message.
		t.Fatalf("Failed to parse generated server cert chain: %v", err)
	}
	if len(certs) < 2 {
		t.Fatalf("Expected cert chain with [leaf, CA], got %d certificate(s)", len(certs))
	}

	caPEM, err := cert.EncodeCertificates(certs[len(certs)-1])
	if err != nil {
		t.Fatalf("Failed to encode CA cert: %v", err)
	}
	return certPEM, keyPEM, caPEM
}
// TestCARotationConnectionBehavior tests end-to-end CA rotation:
// 1. Client trusts server CA via a CA file on disk
// 2. Server rotates to a new CA + serving cert
// 3. Client fails (doesn't trust new CA yet)
// 4. CA file updated, transport reloads, client reconnects
func TestCARotationConnectionBehavior(t *testing.T) {
	t.Log("Testing CA Rotation Connection Behavior")
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true)

	// Generate initial server v1 cert and CA
	servingCertPEM1, servingKeyPEM1, caCertPEM1 := generateServerCertAndCA(t)

	// Start server v1 (no client cert required - test focuses on server CA rotation)
	srv1 := newTestServer(t, servingCertPEM1, servingKeyPEM1)
	srv1.StartTLS()
	defer srv1.Close()

	// Set up the client
	clientCAFile := writeCAFile(t, caCertPEM1)
	config := &Config{
		TLS: TLSConfig{
			CAFile: clientCAFile,
		},
	}
	transport, err := New(config)
	if err != nil {
		t.Fatalf("Failed to create transport: %v", err)
	}

	// Fail with a readable message instead of panicking if New did not hand
	// back the reloading wrapper.
	holder, ok := transport.(*atomicTransportHolder)
	if !ok {
		t.Fatalf("Expected atomicTransportHolder when the feature gate is enabled, got %T", transport)
	}
	// Shorten the refresh interval so the reload happens within the poll window below.
	holder.caRefreshDuration = 500 * time.Millisecond

	client := &http.Client{
		Transport: transport,
		Timeout:   30 * time.Second,
	}

	// Initial connection must succeed
	t.Log("Making initial request to server v1, expecting success...")
	resp, err := client.Get(srv1.URL)
	if err != nil {
		t.Fatalf("Failed to call the server v1: %v", err)
	}
	if err := resp.Body.Close(); err != nil {
		t.Fatal("Failed to close the response.")
	}
	if resp.StatusCode != http.StatusOK {
		t.Fatal("Failed to call the server successfully.")
	}
	t.Log("Initial connection successful.")

	// Rotate: new CA, new server cert, same address
	t.Log("Stopping server v1 and starting server v2 with new CA...")
	srv1Addr := srv1.Listener.Addr().String()
	srv1.Close()

	servingCertPEM2, servingKeyPEM2, caCertPEM2 := generateServerCertAndCA(t)
	srv2 := newTestServer(t, servingCertPEM2, servingKeyPEM2)
	l, err := net.Listen("tcp", srv1Addr)
	if err != nil {
		t.Fatalf("Failed to re-claim the same server address: %v", err)
	}
	srv2.Listener = l
	srv2.StartTLS()
	defer srv2.Close()

	// Must fail - client still trusts old CA
	t.Log("Making request to server v2, expecting failure...")
	if unexpected, err := client.Get(srv2.URL); err == nil {
		// Best-effort close; the test is failing anyway, so the close error is irrelevant.
		_ = unexpected.Body.Close()
		t.Fatal("The request should fail.")
	}
	t.Log("Request failed as expected.")

	// Update CA file to trust new CA
	t.Log("Updating client CA file on disk to trust new CA...")
	if err := os.WriteFile(clientCAFile, caCertPEM2, 0644); err != nil {
		t.Fatalf("Failed to update CA file: %v", err)
	}

	// Poll until transport reloads
	t.Log("Polling server v2 until the client's transport reloads the new CA...")
	var lastPollErr error
	ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer ctxCancel()
	pollErr := wait.PollUntilContextCancel(ctx, 500*time.Millisecond, true, func(ctx context.Context) (bool, error) {
		resp, err := client.Get(srv2.URL)
		if err != nil {
			lastPollErr = err
			t.Log("Client failed to connect before the root CAs are updated, will retry...")
			return false, nil // Error is expected, continue polling
		}
		if err := resp.Body.Close(); err != nil {
			t.Fatal("Failed to close the response.")
		}
		if resp.StatusCode == http.StatusOK {
			return true, nil // Success! Stop polling.
		}
		return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	})
	if pollErr != nil {
		t.Fatalf("Client failed to reconnect after CA rotation. Last error: %v. Test error: %v", lastPollErr, pollErr)
	}
	t.Log("Success! Client reconnected after CA was refreshed.")
}
// TestCARotationConnectionBehavior_Disabled is the counterpart of
// TestCARotationConnectionBehavior with the ClientsAllowCARotation feature
// gate turned off: after the server switches to a new CA and the client's CA
// file is rewritten, the client must keep failing for the whole poll window
// because the transport is not expected to reload the file.
func TestCARotationConnectionBehavior_Disabled(t *testing.T) {
	t.Log("Testing CA Rotation Connection Behavior (Feature Disabled)")
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false)

	// Generate initial server v1 cert and CA
	servingCertPEM1, servingKeyPEM1, caCertPEM1 := generateServerCertAndCA(t)

	// Start server v1 (no client cert required - test focuses on server CA rotation)
	srv1 := newTestServer(t, servingCertPEM1, servingKeyPEM1)
	srv1.StartTLS()
	defer srv1.Close()

	// Set up the client
	clientCAFile := writeCAFile(t, caCertPEM1)
	config := &Config{
		TLS: TLSConfig{
			CAFile: clientCAFile,
		},
	}
	transport, err := New(config)
	if err != nil {
		t.Fatalf("Failed to create transport: %v", err)
	}
	if _, ok := transport.(*atomicTransportHolder); ok {
		t.Fatal("Expected plain transport when the feature gate is disabled, got atomicTransportHolder")
	}
	client := &http.Client{
		Transport: transport,
		Timeout:   30 * time.Second,
	}

	// Initial connection must succeed
	t.Log("Making initial request to server v1, expecting success...")
	resp, err := client.Get(srv1.URL)
	if err != nil {
		t.Fatalf("Failed to call the server v1: %v", err)
	}
	if err := resp.Body.Close(); err != nil {
		t.Fatal("Failed to close the response.")
	}
	if resp.StatusCode != http.StatusOK {
		t.Fatal("Failed to call the server successfully.")
	}
	t.Log("Initial connection successful.")

	// Rotate: new CA, new server cert, same address
	t.Log("Stopping server v1 and starting server v2 with new CA...")
	srv1Addr := srv1.Listener.Addr().String()
	srv1.Close()

	servingCertPEM2, servingKeyPEM2, caCertPEM2 := generateServerCertAndCA(t)
	srv2 := newTestServer(t, servingCertPEM2, servingKeyPEM2)
	l, err := net.Listen("tcp", srv1Addr)
	if err != nil {
		t.Fatalf("Failed to re-claim the same server address: %v", err)
	}
	srv2.Listener = l
	srv2.StartTLS()
	defer srv2.Close()

	// Must fail - client still trusts old CA
	t.Log("Making request to server v2, expecting failure...")
	if unexpected, err := client.Get(srv2.URL); err == nil {
		// Best-effort close; the test is failing anyway, so the close error is irrelevant.
		_ = unexpected.Body.Close()
		t.Fatal("The request should fail.")
	}
	t.Log("Request failed as expected.")

	// Update CA file to trust new CA
	t.Log("Updating client CA file on disk to trust new CA...")
	if err := os.WriteFile(clientCAFile, caCertPEM2, 0644); err != nil {
		t.Fatalf("Failed to update CA file: %v", err)
	}

	// Poll to ensure transport DOES NOT reload
	t.Log("Polling server v2 to verify the client DOES NOT reconnect...")
	var lastPollErr error
	ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer ctxCancel()
	pollErr := wait.PollUntilContextCancel(ctx, 500*time.Millisecond, true, func(ctx context.Context) (bool, error) {
		resp, err := client.Get(srv2.URL)
		if err != nil {
			lastPollErr = err
			t.Log("Client failed to connect (expected because feature is disabled)...")
			return false, nil // Keep polling until the timeout is reached
		}
		if err := resp.Body.Close(); err != nil {
			t.Fatal("Failed to close the response.")
		}
		if resp.StatusCode == http.StatusOK {
			return true, nil // Success! But this means the test failed.
		}
		return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	})

	// If pollErr is nil, it means the connection succeeded, which is wrong for this test.
	if pollErr == nil {
		t.Fatalf("Client unexpectedly reconnected after CA rotation! The feature gate is disabled, so this should have failed.")
	}
	// The only acceptable reason for the poll to stop is the timeout. An error
	// that arrives while the context is still live came from the condition
	// itself (an HTTP response with an unexpected status), which means the TLS
	// handshake succeeded and must not be counted as a pass.
	if ctx.Err() == nil {
		t.Fatalf("Polling stopped before the timeout with an unexpected error: %v", pollErr)
	}
	t.Logf("Success! Client permanently failed to reconnect as expected. Last error: %v", lastPollErr)
}
// caReloadCall captures the labels passed to one Increment call.
type caReloadCall struct {
	result, reason string
}

// testCAReloadsMetric is a fake metric recorder that keeps every Increment
// call, in order, so tests can inspect what was emitted.
type testCAReloadsMetric struct {
	calls []caReloadCall
}

// Increment appends the (result, reason) pair to the recorded calls.
func (m *testCAReloadsMetric) Increment(result, reason string) {
	recorded := caReloadCall{result: result, reason: reason}
	m.calls = append(m.calls, recorded)
}
// TestCARotationMetricsEmitted verifies that ca_rotation.go emits the correct
// metrics during actual CA reload operations.
//
// All cases share one holder and one CA file and run in table order: each
// setup mutates the same file, so the outcome of a case depends on the state
// left behind by the previous ones (the final case removes the file).
func TestCARotationMetricsEmitted(t *testing.T) {
// Swap the package-level metric for a recording fake and restore it afterwards.
fakeMetricRecorder := &testCAReloadsMetric{}
origMetric := metrics.TransportCAReloads
metrics.TransportCAReloads = fakeMetricRecorder
t.Cleanup(func() { metrics.TransportCAReloads = origMetric })
caData := []byte(testCACert1)
caFile := writeCAFile(t, caData)
transport := createTestTransport(t, caData)
// Drive the refresh interval with a fake clock instead of real time.
clock := testingclock.NewFakeClock(time.Now())
holder := newAtomicTransportHolder(caFile, caData, transport)
holder.clock = clock
holder.transportLastChecked = clock.Now()
tests := []struct {
name string
// setup prepares the CA file for the case before the refresh is triggered.
setup func() error
wantResult string
wantReason string
}{
{
name: "unchanged CA",
setup: func() error { return nil },
wantResult: "success",
wantReason: "unchanged",
},
{
name: "updated CA",
setup: func() error {
return os.WriteFile(caFile, []byte(testCACert2), 0644)
},
wantResult: "success",
wantReason: "updated",
},
{
name: "empty file",
setup: func() error {
return os.WriteFile(caFile, []byte{}, 0644)
},
wantResult: "failure",
wantReason: "empty",
},
{
name: "invalid CA data",
setup: func() error {
return os.WriteFile(caFile, []byte("not-a-cert"), 0644)
},
wantResult: "failure",
wantReason: "ca_parse_error",
},
{
name: "read error",
setup: func() error {
return os.Remove(caFile)
},
wantResult: "failure",
wantReason: "read_error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Start each case with an empty record so exactly one call is attributable to it.
fakeMetricRecorder.calls = nil
err := tt.setup()
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// Advance past the refresh interval so getTransport performs a CA file check.
clock.Step(holder.caRefreshDuration)
holder.getTransport(context.Background())
if len(fakeMetricRecorder.calls) != 1 {
t.Fatalf("Expected 1 metric call, got %d", len(fakeMetricRecorder.calls))
}
if fakeMetricRecorder.calls[0].result != tt.wantResult || fakeMetricRecorder.calls[0].reason != tt.wantReason {
t.Errorf("Got metric(%s, %s), want (%s, %s)",
fakeMetricRecorder.calls[0].result, fakeMetricRecorder.calls[0].reason, tt.wantResult, tt.wantReason)
}
})
}
}
// newTestServer returns an unstarted httptest.Server that answers every
// request with 200 and the body "ok", and whose TLS config serves the given
// PEM-encoded certificate and key. The caller is expected to start the server
// (for example with StartTLS) and to Close it.
//
// The test is failed immediately if the key pair cannot be parsed.
func newTestServer(t *testing.T, certPEM, keyPEM []byte) *httptest.Server {
	t.Helper()

	// Parse the key pair first so that a failure does not leave an unstarted
	// server (and its listener) behind.
	tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		t.Fatalf("Failed to create server cert: %v", err)
	}

	server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		if _, err := fmt.Fprint(w, "ok"); err != nil {
			// The handler runs on a server goroutine, where Fatal/FailNow must
			// not be called; Errorf is safe from any goroutine.
			t.Errorf("Failed to write to the response: %v", err)
		}
	}))

	// Configure server TLS
	server.TLS = &tls.Config{
		Certificates: []tls.Certificate{tlsCert},
	}
	return server
}

View File

@@ -36,7 +36,7 @@ import (
// the config has no custom TLS options, http.DefaultTransport is returned.
type tlsTransportCache struct {
mu sync.Mutex
transports map[tlsCacheKey]*http.Transport
transports map[tlsCacheKey]http.RoundTripper
}
// DialerStopCh is stop channel that is passed down to dynamic cert dialer.
@@ -46,11 +46,14 @@ var DialerStopCh = wait.NeverStop
const idleConnsPerHost = 25
var tlsCache = &tlsTransportCache{transports: make(map[tlsCacheKey]*http.Transport)}
var tlsCache = &tlsTransportCache{
transports: make(map[tlsCacheKey]http.RoundTripper),
}
type tlsCacheKey struct {
insecure bool
caData string
caFile string
certData string
keyData string `datapolicy:"security-key"`
certFile string
@@ -68,8 +71,8 @@ func (t tlsCacheKey) String() string {
if len(t.keyData) > 0 {
keyText = "<redacted>"
}
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p",
t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial)
return fmt.Sprintf("insecure:%v, caData:%#v, caFile:%s, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p",
t.insecure, t.caData, t.caFile, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial)
}
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
@@ -131,7 +134,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
proxy = config.Proxy
}
transport := utilnet.SetTransportDefaults(&http.Transport{
httpTransport := utilnet.SetTransportDefaults(&http.Transport{
Proxy: proxy,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
@@ -139,6 +142,11 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
DialContext: dial,
DisableCompression: config.DisableCompression,
})
var transport http.RoundTripper = httpTransport
if config.TLS.ReloadCAFiles && tlsConfig != nil && tlsConfig.RootCAs != nil && len(config.TLS.CAFile) > 0 {
transport = newAtomicTransportHolder(config.TLS.CAFile, config.TLS.CAData, httpTransport)
}
if canCache {
// Cache a single transport for these options
@@ -162,7 +170,6 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
k := tlsCacheKey{
insecure: c.TLS.Insecure,
caData: string(c.TLS.CAData),
serverName: c.TLS.ServerName,
nextProtos: strings.Join(c.TLS.NextProtos, ","),
disableCompression: c.DisableCompression,
@@ -178,5 +185,14 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
k.keyData = string(c.TLS.KeyData)
}
if c.TLS.ReloadCAFiles {
// When reloading CA files, include CA file path in cache key instead of CA data
// This allows the CA to be reloaded from disk on each transport creation
k.caFile = c.TLS.CAFile
} else {
// When not reloading, cache the CA data directly
k.caData = string(c.TLS.CAData)
}
return k, true, nil
}

View File

@@ -19,14 +19,22 @@ package transport
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"net/http"
"net/url"
"os"
"reflect"
"testing"
"time"
clientgofeaturegate "k8s.io/client-go/features"
clientfeaturestesting "k8s.io/client-go/features/testing"
)
func TestTLSConfigKey(t *testing.T) {
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true)
// Make sure config fields that don't affect the tls config don't affect the cache key
identicalConfigurations := map[string]*Config{
"empty": {},
@@ -70,14 +78,17 @@ func TestTLSConfigKey(t *testing.T) {
// Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{}
getCert := &GetCertHolder{GetCert: func() (*tls.Certificate, error) { return nil, nil }}
caFile := writeCAFile(t, []byte(testCACert1))
uniqueConfigurations := map[string]*Config{
"proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }},
"no tls": {},
"dialer": {DialHolder: &DialHolder{Dial: dialer.DialContext}},
"dialer2": {DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
"insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
"proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }},
"no tls": {},
"dialer": {DialHolder: &DialHolder{Dial: dialer.DialContext}},
"dialer2": {DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
"insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
"with only ca file": {TLS: TLSConfig{CAFile: caFile}},
"with both ca file and ca data": {TLS: TLSConfig{CAFile: caFile, CAData: []byte(testCACert1)}},
"cert 1, key 1": {
TLS: TLSConfig{
CertData: []byte{1},
@@ -188,3 +199,228 @@ func TestTLSConfigKey(t *testing.T) {
}
}
}
// TestTLSConfigKeyCARotationDisabled checks that, with the
// ClientsAllowCARotation gate off, a config that names a CA file and a config
// that carries the same bytes as CAData produce the same cache key once
// loadTLSFiles has run, and that ReloadCAFiles stays false.
func TestTLSConfigKeyCARotationDisabled(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false)

	caFile := writeCAFile(t, []byte(testCACert1))

	// With the feature disabled, the file-only config has its CAData filled in
	// by loadTLSFiles, so both configs should end up keyed identically.
	fromFile := &Config{TLS: TLSConfig{CAFile: caFile}}
	fromData := &Config{TLS: TLSConfig{CAData: []byte(testCACert1)}}
	for _, cfg := range []*Config{fromFile, fromData} {
		if err := loadTLSFiles(cfg); err != nil {
			t.Fatal(err)
		}
	}

	fileKey, fileCacheable, err := tlsConfigKey(fromFile)
	if err != nil || !fileCacheable {
		t.Fatalf("unexpected: err=%v, canCache=%v", err, fileCacheable)
	}
	dataKey, dataCacheable, err := tlsConfigKey(fromData)
	if err != nil || !dataCacheable {
		t.Fatalf("unexpected: err=%v, canCache=%v", err, dataCacheable)
	}

	if fileKey != dataKey {
		t.Error("Expected same cache key when feature is disabled (CAFile resolved to CAData)")
	}
	if fromFile.TLS.ReloadCAFiles {
		t.Error("Expected ReloadCAFiles=false when feature gate is disabled")
	}
}
// TestTLSTransportCacheCARotation tests transport cache behavior with CA
// rotation enabled. For each config it checks which kind of round tripper a
// fresh cache returns (the reloading wrapper, a plain *http.Transport, or
// http.DefaultTransport), and for cacheable configs that a second lookup
// returns the same instance and that the cache holds exactly one entry.
func TestTLSTransportCacheCARotation(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true)

	caFile := writeCAFile(t, []byte(testCACert1))
	testCases := []struct {
		name            string
		config          *Config
		expectWrapper   bool
		expectCacheable bool
	}{
		{
			name: "CA rotation should be enabled when only the CAFile is set",
			config: &Config{
				TLS: TLSConfig{
					CAFile: caFile,
				},
			},
			expectWrapper:   true,
			expectCacheable: true,
		},
		{
			name: "CA rotation should be disabled when both CAFile and CAData are set",
			config: &Config{
				TLS: TLSConfig{
					CAFile: caFile,
					CAData: []byte(testCACert1),
				},
			},
			expectWrapper:   false,
			expectCacheable: true,
		},
		{
			name: "CA rotation should be disabled when only the CAData is set",
			config: &Config{
				TLS: TLSConfig{
					CAData: []byte(testCACert1),
				},
			},
			expectWrapper:   false,
			expectCacheable: true,
		},
		{
			name: "no TLS config",
			config: &Config{
				TLS: TLSConfig{},
			},
			expectWrapper:   false,
			expectCacheable: false, // No TLS config means default transport
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Create new cache for testing
			tlsCaches := &tlsTransportCache{
				transports: make(map[tlsCacheKey]http.RoundTripper),
			}
			rt, err := tlsCaches.get(tc.config)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			if !tc.expectCacheable {
				// Should return default transport
				if rt != http.DefaultTransport {
					t.Errorf("Expected default transport, got %T", rt)
				}
				return
			}

			if tc.expectWrapper {
				// Should be wrapped in atomicTransportHolder
				if _, ok := rt.(*atomicTransportHolder); !ok {
					t.Errorf("Expected atomicTransportHolder, got %T", rt)
				}
				if !tc.config.TLS.ReloadCAFiles {
					t.Errorf("Expected ReloadCAFiles to be true, got %v", tc.config.TLS.ReloadCAFiles)
				}
			} else {
				// Should be a regular http.Transport
				if _, ok := rt.(*http.Transport); !ok {
					t.Errorf("Expected *http.Transport, got %T", rt)
				}
			}

			// Test caching: second call should return the same instance
			rt2, err := tlsCaches.get(tc.config)
			if err != nil {
				t.Fatalf("Unexpected error on second call: %v", err)
			}
			if rt != rt2 {
				t.Error("Expected same transport instance from cache")
			}

			// Verify cache size. Non-cacheable cases returned above, so every
			// case reaching this point must have stored exactly one entry.
			tlsCaches.mu.Lock()
			cacheSize := len(tlsCaches.transports)
			tlsCaches.mu.Unlock()
			const expectedCacheSize = 1
			if cacheSize != expectedCacheSize {
				t.Errorf("Expected %d transports in cache, got %d", expectedCacheSize, cacheSize)
			}
		})
	}
}
// TestTLSTransportCacheCARotationDisabled checks that, with the
// ClientsAllowCARotation feature gate turned off, a config that only sets
// CAFile yields a plain *http.Transport from the cache rather than an
// atomicTransportHolder.
func TestTLSTransportCacheCARotationDisabled(t *testing.T) {
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false)

	cfg := &Config{
		TLS: TLSConfig{
			CAFile: writeCAFile(t, []byte(testCACert1)),
		},
	}
	transports := &tlsTransportCache{
		transports: make(map[tlsCacheKey]http.RoundTripper),
	}

	got, err := transports.get(cfg)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Classify the returned round tripper once, then report on each expectation.
	_, isHolder := got.(*atomicTransportHolder)
	_, isPlain := got.(*http.Transport)
	if isHolder {
		t.Error("Expected plain *http.Transport when feature gate is disabled, got atomicTransportHolder")
	}
	if !isPlain {
		t.Errorf("Expected *http.Transport, got %T", got)
	}
}
// TestEmptyCAFileRotationLifecycle walks a CA file through an
// empty -> populated transition with the ClientsAllowCARotation feature gate
// enabled. It verifies that:
//   - a zero-byte CAFile still produces an atomicTransportHolder whose RootCAs
//     is a non-nil, empty pool, and
//   - after valid cert data is written to the file and the holder's last-checked
//     timestamp is pushed into the past, the next getTransport call returns a
//     transport whose RootCAs is non-nil and no longer empty.
func TestEmptyCAFileRotationLifecycle(t *testing.T) {
	// Enable the feature gate for the duration of the test
	clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true)

	// Create a valid file path, but with empty cert data
	emptyFile := writeCAFile(t, []byte{})
	config := &Config{
		TLS: TLSConfig{
			CAFile: emptyFile,
		},
	}
	tlsCaches := &tlsTransportCache{
		transports: make(map[tlsCacheKey]http.RoundTripper),
	}

	rt, err := tlsCaches.get(config)
	if err != nil {
		t.Fatalf("Unexpected error getting transport: %v", err)
	}

	// Verify newAtomicTransportHolder is successfully generated
	holder, ok := rt.(*atomicTransportHolder)
	if !ok {
		t.Fatalf("Expected atomicTransportHolder, got %T", rt)
	}

	// Verify the initial state: RootCAs should be non-nil but empty
	initialTransport := holder.getTransport(context.Background())
	if initialTransport.TLSClientConfig == nil || initialTransport.TLSClientConfig.RootCAs == nil {
		t.Fatal("Expected RootCAs to be non-nil for an empty CA file (should be an empty CertPool)")
	}
	emptyPool := x509.NewCertPool()
	if !initialTransport.TLSClientConfig.RootCAs.Equal(emptyPool) {
		t.Fatal("Expected initially empty RootCAs")
	}

	// Write valid cert data into the CA file
	if err := os.WriteFile(emptyFile, []byte(testCACert1), 0644); err != nil {
		t.Fatalf("Failed to write to CA file: %v", err)
	}

	holder.mu.Lock()
	// Set last checked time far in the past to force a refresh on next getTransport call
	holder.transportLastChecked = time.Now().Add(-time.Hour)
	holder.mu.Unlock()

	refreshedTransport := holder.getTransport(context.Background())

	// Guard against nil before comparing. A nil TLSClientConfig would panic on the
	// field access below, and (*x509.CertPool).Equal reports false when only one
	// side is nil, so a nil RootCAs would otherwise be mistaken for a populated pool.
	if refreshedTransport.TLSClientConfig == nil || refreshedTransport.TLSClientConfig.RootCAs == nil {
		t.Fatal("Expected RootCAs to be non-nil after writing valid cert data and refreshing")
	}

	// Verify the refresh succeeded and the cert pool is now populated
	if refreshedTransport.TLSClientConfig.RootCAs.Equal(emptyPool) {
		t.Fatal("Expected RootCAs to be populated after writing valid cert data and refreshing")
	}
}

View File

@@ -134,7 +134,8 @@ type TLSConfig struct {
CAFile string // Path of the PEM-encoded server trusted root certificates.
CertFile string // Path of the PEM-encoded client certificate.
KeyFile string // Path of the PEM-encoded client key.
ReloadTLSFiles bool // Set to indicate that the original config provided files, and that they should be reloaded
ReloadTLSFiles bool // Set to indicate that the original config provided files, and that they should be reloaded.
ReloadCAFiles bool // Set to indicate that CA files should be reloaded from disk.
Insecure bool // Server should be accessed without verifying the certificate. For testing only.
ServerName string // Override for the server name passed to the server for SNI and used to verify certificates.

View File

@@ -28,6 +28,7 @@ import (
"time"
utilnet "k8s.io/apimachinery/pkg/util/net"
clientgofeaturegate "k8s.io/client-go/features"
"k8s.io/klog/v2"
)
@@ -211,17 +212,26 @@ func TLSConfigFor(c *Config) (*tls.Config, error) {
// KeyData, and CAFile fields, or returns an error. If no error is returned, all three fields are
// either populated or were empty to start.
func loadTLSFiles(c *Config) error {
// Check that we are purely loading CA from file
if clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowCARotation) {
if len(c.TLS.CAFile) > 0 && len(c.TLS.CAData) == 0 {
c.TLS.ReloadCAFiles = true
}
} else if c.TLS.ReloadCAFiles {
return fmt.Errorf("ReloadCAFiles=true requires ClientsAllowCARotation to be enabled")
}
// Check that we are purely loading certs and keys from files
if len(c.TLS.CertFile) > 0 && len(c.TLS.CertData) == 0 && len(c.TLS.KeyFile) > 0 && len(c.TLS.KeyData) == 0 {
c.TLS.ReloadTLSFiles = true
}
var err error
c.TLS.CAData, err = dataFromSliceOrFile(c.TLS.CAData, c.TLS.CAFile)
if err != nil {
return err
}
// Check that we are purely loading from files
if len(c.TLS.CertFile) > 0 && len(c.TLS.CertData) == 0 && len(c.TLS.KeyFile) > 0 && len(c.TLS.KeyData) == 0 {
c.TLS.ReloadTLSFiles = true
}
c.TLS.CertData, err = dataFromSliceOrFile(c.TLS.CertData, c.TLS.CertFile)
if err != nil {
return err
@@ -254,6 +264,11 @@ func rootCertPool(caData []byte) (*x509.CertPool, error) {
// code for a look at the platform specific insanity), so we'll use the fact that RootCAs == nil gives us the system values
// It doesn't allow trusting either/or, but hopefully that won't be an issue
if len(caData) == 0 {
// When the ClientsAllowCARotation feature gate is enabled, it returns an empty but non-nil pool.
// This ensures we don't fall back to system roots when a user explicitly points CAFile to a zero-byte file.
if clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowCARotation) {
return x509.NewCertPool(), nil
}
return nil, nil
}

View File

@@ -19,11 +19,15 @@ package transport
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/http"
"testing"
clientgofeaturegate "k8s.io/client-go/features"
clientfeaturestesting "k8s.io/client-go/features/testing"
)
const (
@@ -356,6 +360,16 @@ func TestNew(t *testing.T) {
}
for k, testCase := range testCases {
t.Run(k, func(t *testing.T) {
// The Close method of httptest Server mutates the
// `http.DefaultTransport` object, the 'TLSClientConfig'
// field mutates from nil to a non nil instance. This introduces flake
// and data race when running tests under transport package in parallel.
// To work around it we reset the TLSClientConfig field.
//
// See: https://github.com/golang/go/issues/65796
if testCase.Default {
http.DefaultTransport.(*http.Transport).TLSClientConfig = nil
}
rt, err := New(testCase.Config)
switch {
case testCase.Err && err == nil:
@@ -556,3 +570,51 @@ func Test_contextCanceller_RoundTrip(t *testing.T) {
})
}
}
// TestRootCertPoolEmptyData exercises rootCertPool with zero-length CA data
// under both settings of the ClientsAllowCARotation feature gate: with the
// gate off the result must be a nil pool, and with the gate on it must be a
// non-nil pool equal to a freshly created empty x509.CertPool.
func TestRootCertPoolEmptyData(t *testing.T) {
	cases := []struct {
		name    string
		gateOn  bool
		wantNil bool
	}{
		{
			name:    "feature gate disabled returns nil (system roots)",
			gateOn:  false,
			wantNil: true,
		},
		{
			name:    "feature gate enabled returns empty pool (trust nothing)",
			gateOn:  true,
			wantNil: false,
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// Toggle the gate for this subtest only.
			clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, c.gateOn)

			got, err := rootCertPool([]byte{})
			if err != nil {
				t.Fatalf("Expected no error, got %v", err)
			}

			// Cases are evaluated in order, so the Equal comparison is only
			// reached for a non-nil pool when a non-nil result is expected.
			switch {
			case c.wantNil && got != nil:
				t.Fatalf("Expected pool to be nil when feature gate is disabled, but got a populated pool")
			case !c.wantNil && got == nil:
				t.Fatalf("Expected pool to be non-nil (empty pool) when feature gate is enabled, but got nil")
			case !c.wantNil && !got.Equal(x509.NewCertPool()):
				t.Fatalf("Expected the returned pool to be completely empty")
			}
		})
	}
}