mirror of
https://github.com/kubernetes/client-go.git
synced 2026-05-14 19:17:56 +00:00
client-go CA rotation
Kubernetes-commit: b806a3207beb1afabc3d3d93fbadba38ffb8a110
This commit is contained in:
committed by
Kubernetes Publisher
parent
6f2c112d22
commit
8776b282cc
@@ -38,6 +38,13 @@ const (
|
||||
// events for items popped off the FIFO.
|
||||
AtomicFIFO Feature = "AtomicFIFO"
|
||||
|
||||
// owner: @yt2985
|
||||
// beta: 1.36
|
||||
//
|
||||
// If enabled, allows clients to gracefully handle Certificate Authority (CA)
|
||||
// rotations without dropping connections or requiring a restart.
|
||||
ClientsAllowCARotation Feature = "ClientsAllowCARotation"
|
||||
|
||||
// owner: @benluddy
|
||||
// kep: https://kep.k8s.io/4222
|
||||
// alpha: 1.32
|
||||
@@ -98,6 +105,9 @@ var defaultVersionedKubernetesFeatureGates = map[Feature]VersionedSpecs{
|
||||
AtomicFIFO: {
|
||||
{Version: version.MustParse("1.36"), Default: true, PreRelease: Beta},
|
||||
},
|
||||
ClientsAllowCARotation: {
|
||||
{Version: version.MustParse("1.36"), Default: true, PreRelease: Beta},
|
||||
},
|
||||
ClientsAllowCBOR: {
|
||||
{Version: version.MustParse("1.32"), Default: false, PreRelease: Alpha},
|
||||
},
|
||||
|
||||
@@ -85,6 +85,12 @@ type TransportCreateCallsMetric interface {
|
||||
Increment(result string)
|
||||
}
|
||||
|
||||
// TransportCAReloadsMetric counts the number of times a CA reload is attempted,
|
||||
// partitioned by the result and reason.
|
||||
type TransportCAReloadsMetric interface {
|
||||
Increment(result, reason string)
|
||||
}
|
||||
|
||||
var (
|
||||
// ClientCertExpiry is the expiry time of a client certificate
|
||||
ClientCertExpiry ExpiryMetric = noopExpiry{}
|
||||
@@ -117,6 +123,9 @@ var (
|
||||
// TransportCreateCalls is the metric that counts the number of times a new transport
|
||||
// is created
|
||||
TransportCreateCalls TransportCreateCallsMetric = noopTransportCreateCalls{}
|
||||
|
||||
// TransportCAReloads is the metric that counts the number of times a CA reload is attempted
|
||||
TransportCAReloads TransportCAReloadsMetric = noopTransportCAReloads{}
|
||||
)
|
||||
|
||||
// RegisterOpts contains all the metrics to register. Metrics may be nil.
|
||||
@@ -134,6 +143,7 @@ type RegisterOpts struct {
|
||||
RequestRetry RetryMetric
|
||||
TransportCacheEntries TransportCacheMetric
|
||||
TransportCreateCalls TransportCreateCallsMetric
|
||||
TransportCAReloads TransportCAReloadsMetric
|
||||
}
|
||||
|
||||
// Register registers metrics for the rest client to use. This can
|
||||
@@ -179,6 +189,9 @@ func Register(opts RegisterOpts) {
|
||||
if opts.TransportCreateCalls != nil {
|
||||
TransportCreateCalls = opts.TransportCreateCalls
|
||||
}
|
||||
if opts.TransportCAReloads != nil {
|
||||
TransportCAReloads = opts.TransportCAReloads
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -226,3 +239,7 @@ func (noopTransportCache) Observe(int) {}
|
||||
type noopTransportCreateCalls struct{}
|
||||
|
||||
func (noopTransportCreateCalls) Increment(string) {}
|
||||
|
||||
type noopTransportCAReloads struct{}
|
||||
|
||||
func (noopTransportCAReloads) Increment(result, reason string) {}
|
||||
|
||||
154
transport/ca_rotation.go
Normal file
154
transport/ca_rotation.go
Normal file
@@ -0,0 +1,154 @@
|
||||
/*
|
||||
Copyright The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
utilnet "k8s.io/apimachinery/pkg/util/net"
|
||||
"k8s.io/client-go/tools/metrics"
|
||||
"k8s.io/klog/v2"
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
var _ utilnet.RoundTripperWrapper = &atomicTransportHolder{}
|
||||
|
||||
// atomicTransportHolder holds a transport that can be atomically updated
|
||||
// when CA files change, enabling graceful CA rotation without cache complexity
|
||||
type atomicTransportHolder struct {
|
||||
caFile string
|
||||
currentCAData []byte // Track the actual CA data currently in use
|
||||
// clock and caRefreshDuration are used to allow for testing time-based logic.
|
||||
clock clock.Clock
|
||||
caRefreshDuration time.Duration
|
||||
// mu covers transport and transportLastChecked
|
||||
mu sync.RWMutex
|
||||
transport *http.Transport
|
||||
transportLastChecked time.Time
|
||||
}
|
||||
|
||||
func (h *atomicTransportHolder) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return h.getTransport(req.Context()).RoundTrip(req)
|
||||
}
|
||||
|
||||
func (h *atomicTransportHolder) WrappedRoundTripper() http.RoundTripper {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
return h.transport
|
||||
}
|
||||
|
||||
func (h *atomicTransportHolder) getTransport(ctx context.Context) *http.Transport {
|
||||
if rt := h.getTransportIfFresh(); rt != nil {
|
||||
return rt
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.tryRefreshTransportLocked(ctx)
|
||||
return h.transport
|
||||
}
|
||||
|
||||
func (h *atomicTransportHolder) getTransportIfFresh() *http.Transport {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
if h.clock.Since(h.transportLastChecked) < h.caRefreshDuration {
|
||||
return h.transport
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *atomicTransportHolder) tryRefreshTransportLocked(ctx context.Context) {
|
||||
// If some other goroutine already checked/updated the CA
|
||||
if h.clock.Since(h.transportLastChecked) < h.caRefreshDuration {
|
||||
return
|
||||
}
|
||||
|
||||
// only attempt CA reload once per caRefreshDuration, even if the reload fails
|
||||
h.transportLastChecked = h.clock.Now()
|
||||
|
||||
logger := klog.FromContext(ctx).WithValues("caFile", h.caFile)
|
||||
|
||||
logger.V(4).Info("Checking CA file content")
|
||||
|
||||
// Load new CA data from file
|
||||
newCAData, err := os.ReadFile(h.caFile)
|
||||
// Return old transport on read error
|
||||
if err != nil {
|
||||
logger.Error(err, "Failed to read CA data from file")
|
||||
metrics.TransportCAReloads.Increment("failure", "read_error")
|
||||
return
|
||||
}
|
||||
|
||||
if len(newCAData) == 0 {
|
||||
logger.Info("CA file empty, skipping transport rotation")
|
||||
metrics.TransportCAReloads.Increment("failure", "empty")
|
||||
return
|
||||
}
|
||||
|
||||
if bytes.Equal(h.currentCAData, newCAData) {
|
||||
logger.V(4).Info("CA file unchanged, skipping transport rotation")
|
||||
metrics.TransportCAReloads.Increment("success", "unchanged")
|
||||
return
|
||||
}
|
||||
|
||||
logger.V(4).Info("CA content changed, updating transport")
|
||||
|
||||
// Load new CA pool
|
||||
newCAs, err := rootCertPool(newCAData)
|
||||
// Return old transport on parse error
|
||||
if err != nil {
|
||||
logger.Error(err, "Failed to parse CA data from file")
|
||||
metrics.TransportCAReloads.Increment("failure", "ca_parse_error")
|
||||
return
|
||||
}
|
||||
newTransport := h.transport.Clone()
|
||||
newTransport.TLSClientConfig.RootCAs = newCAs
|
||||
oldTransport := h.transport
|
||||
h.transport = newTransport
|
||||
// Update our tracking of current CA data
|
||||
h.currentCAData = newCAData
|
||||
|
||||
// Close idle connections on the old transport to encourage migration
|
||||
oldTransport.CloseIdleConnections()
|
||||
|
||||
logger.V(4).Info("Transport updated for CA rotation")
|
||||
metrics.TransportCAReloads.Increment("success", "updated")
|
||||
}
|
||||
|
||||
// newAtomicTransportHolder creates a new holder for CA file reloading scenarios.
|
||||
// The caFile must be specified.
|
||||
// caData may be empty but should correspond to the contents of caFile.
|
||||
// transport must have a TLS config and its root CAs should match caData.
|
||||
func newAtomicTransportHolder(caFile string, caData []byte, transport *http.Transport) *atomicTransportHolder {
|
||||
c := clock.RealClock{}
|
||||
return &atomicTransportHolder{
|
||||
caFile: caFile,
|
||||
currentCAData: caData,
|
||||
clock: c,
|
||||
caRefreshDuration: 5 * time.Minute,
|
||||
transport: transport,
|
||||
transportLastChecked: c.Now(),
|
||||
}
|
||||
}
|
||||
560
transport/ca_rotation_test.go
Normal file
560
transport/ca_rotation_test.go
Normal file
@@ -0,0 +1,560 @@
|
||||
/*
|
||||
Copyright The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/wait"
|
||||
clientgofeaturegate "k8s.io/client-go/features"
|
||||
clientfeaturestesting "k8s.io/client-go/features/testing"
|
||||
"k8s.io/client-go/tools/metrics"
|
||||
"k8s.io/client-go/util/cert"
|
||||
testingclock "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
const (
|
||||
// Use the same rootCACert as transport_test.go
|
||||
testCACert1 = `-----BEGIN CERTIFICATE-----
|
||||
MIIC4DCCAcqgAwIBAgIBATALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
|
||||
MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIxNTczN1oXDTE2MDExNTIxNTcz
|
||||
OFowIzEhMB8GA1UEAwwYMTAuMTMuMTI5LjEwNkAxNDIxMzU5MDU4MIIBIjANBgkq
|
||||
hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAunDRXGwsiYWGFDlWH6kjGun+PshDGeZX
|
||||
xtx9lUnL8pIRWH3wX6f13PO9sktaOWW0T0mlo6k2bMlSLlSZgG9H6og0W6gLS3vq
|
||||
s4VavZ6DbXIwemZG2vbRwsvR+t4G6Nbwelm6F8RFnA1Fwt428pavmNQ/wgYzo+T1
|
||||
1eS+HiN4ACnSoDSx3QRWcgBkB1g6VReofVjx63i0J+w8Q/41L9GUuLqquFxu6ZnH
|
||||
60vTB55lHgFiDLjA1FkEz2dGvGh/wtnFlRvjaPC54JH2K1mPYAUXTreoeJtLJKX0
|
||||
ycoiyB24+zGCniUmgIsmQWRPaOPircexCp1BOeze82BT1LCZNTVaxQIDAQABoyMw
|
||||
ITAOBgNVHQ8BAf8EBAMCAKQwDwYDVR0TAQH/BAUwAwEB/zALBgkqhkiG9w0BAQsD
|
||||
ggEBADMxsUuAFlsYDpF4fRCzXXwrhbtj4oQwcHpbu+rnOPHCZupiafzZpDu+rw4x
|
||||
YGPnCb594bRTQn4pAu3Ac18NbLD5pV3uioAkv8oPkgr8aUhXqiv7KdDiaWm6sbAL
|
||||
EHiXVBBAFvQws10HMqMoKtO8f1XDNAUkWduakR/U6yMgvOPwS7xl0eUTqyRB6zGb
|
||||
K55q2dejiFWaFqB/y78txzvz6UlOZKE44g2JAVoJVM6kGaxh33q8/FmrL4kuN3ut
|
||||
W+MmJCVDvd4eEqPwbp7146ZWTqpIJ8lvA6wuChtqV8lhAPka2hD/LMqY8iXNmfXD
|
||||
uml0obOEy+ON91k+SWTJ3ggmF/U=
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
// A different CA cert for testing rotation (modified version of certData from transport_test.go)
|
||||
testCACert2 = `-----BEGIN CERTIFICATE-----
|
||||
MIIC6jCCAdSgAwIBAgIBCzALBgkqhkiG9w0BAQswIzEhMB8GA1UEAwwYMTAuMTMu
|
||||
MTI5LjEwNkAxNDIxMzU5MDU4MB4XDTE1MDExNTIyMDEzMVoXDTE2MDExNTIyMDEz
|
||||
MlowGzEZMBcGA1UEAxMQb3BlbnNoaWZ0LWNsaWVudDCCASIwDQYJKoZIhvcNAQEB
|
||||
BQADggEPADCCAQoCggEBAKtdhz0+uCLXw5cSYns9rU/XifFSpb/x24WDdrm72S/v
|
||||
b9BPYsAStiP148buylr1SOuNi8sTAZmlVDDIpIVwMLff+o2rKYDicn9fjbrTxTOj
|
||||
lI4pHJBH+JU3AJ0tbajupioh70jwFS0oYpwtneg2zcnE2Z4l6mhrj2okrc5Q1/X2
|
||||
I2HChtIU4JYTisObtin10QKJX01CLfYXJLa8upWzKZ4/GOcHG+eAV3jXWoXidtjb
|
||||
1Usw70amoTZ6mIVCkiu1QwCoa8+ycojGfZhvqMsAp1536ZcCul+Na+AbCv4zKS7F
|
||||
kQQaImVrXdUiFansIoofGlw/JNuoKK6ssVpS5Ic3pgcCAwEAAaM1MDMwDgYDVR0P
|
||||
AQH/BAQDAgCgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMAwGA1UdEwEB/wQCMAAwCwYJ
|
||||
KoZIhvcNAQELA4IBAQCKLREH7bXtXtZ+8vI6cjD7W3QikiArGqbl36bAhhWsJLp/
|
||||
p/ndKz39iFNaiZ3GlwIURWOOKx3y3GA0x9m8FR+Llthf0EQ8sUjnwaknWs0Y6DQ3
|
||||
jjPFZOpV3KPCFrdMJ3++E3MgwFC/Ih/N2ebFX9EcV9Vcc6oVWMdwT0fsrhu683rq
|
||||
6GSR/3iVX1G/pmOiuaR0fNUaCyCfYrnI4zHBDgSfnlm3vIvN2lrsR/DQBakNL8DJ
|
||||
HBgKxMGeUPoneBv+c8DMXIL0EhaFXRlBv9QW45/GiAIOuyFJ0i6hCtGZpJjq4OpQ
|
||||
BRjCI+izPzFTjsxD4aORE+WOkyWFCGPWKfNejfw0
|
||||
-----END CERTIFICATE-----`
|
||||
)
|
||||
|
||||
// writeCAFile writes CA data to a temporary file
|
||||
func writeCAFile(t testing.TB, caData []byte) string {
|
||||
tmpDir := t.TempDir()
|
||||
caFile := filepath.Join(tmpDir, "ca.crt")
|
||||
|
||||
err := os.WriteFile(caFile, caData, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write CA file: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := os.Remove(caFile); err != nil {
|
||||
t.Logf("unexpected error while removing file: %s - %v", caFile, err)
|
||||
}
|
||||
})
|
||||
return caFile
|
||||
}
|
||||
|
||||
// createTestTransport creates a test transport with TLS config
|
||||
func createTestTransport(t testing.TB, caData []byte) *http.Transport {
|
||||
CAs, err := rootCertPool(caData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse CA certificate")
|
||||
}
|
||||
return &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: CAs,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckCAFileAndRotate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupCA []byte
|
||||
updateCA []byte
|
||||
caFileOverride string
|
||||
expectRotation bool
|
||||
}{
|
||||
{
|
||||
name: "no change",
|
||||
setupCA: []byte(testCACert1),
|
||||
updateCA: []byte(testCACert1), // Same CA
|
||||
expectRotation: false,
|
||||
},
|
||||
{
|
||||
name: "CA changed",
|
||||
setupCA: []byte(testCACert1),
|
||||
updateCA: []byte(testCACert2), // Different CA
|
||||
expectRotation: true,
|
||||
},
|
||||
{
|
||||
name: "CA changed to invalid",
|
||||
setupCA: []byte(testCACert1),
|
||||
updateCA: []byte("panda"), // invalid CA
|
||||
expectRotation: false,
|
||||
},
|
||||
{
|
||||
name: "file error",
|
||||
setupCA: []byte(testCACert1),
|
||||
caFileOverride: "/nonexistent/ca.crt", // Non-existent file
|
||||
expectRotation: false,
|
||||
},
|
||||
{
|
||||
name: "empty file content",
|
||||
setupCA: []byte(testCACert1),
|
||||
updateCA: []byte{}, // Empty file
|
||||
expectRotation: false,
|
||||
},
|
||||
{
|
||||
name: "initially empty CA file updated to valid CA",
|
||||
setupCA: []byte{}, // Starts with empty CA
|
||||
updateCA: []byte(testCACert1), // Populated with a valid CA
|
||||
expectRotation: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
caFile := writeCAFile(t, tt.setupCA)
|
||||
if len(tt.caFileOverride) > 0 {
|
||||
caFile = tt.caFileOverride
|
||||
}
|
||||
|
||||
transport := createTestTransport(t, tt.setupCA)
|
||||
setupRoots := transport.TLSClientConfig.RootCAs.Clone()
|
||||
|
||||
expectedRoots := setupRoots
|
||||
if tt.expectRotation {
|
||||
var err error
|
||||
expectedRoots, err = rootCertPool(tt.updateCA)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
clock := testingclock.NewFakeClock(time.Now())
|
||||
holder := newAtomicTransportHolder(caFile, tt.setupCA, transport)
|
||||
holder.clock = clock
|
||||
holder.transportLastChecked = clock.Now()
|
||||
|
||||
if tt.updateCA != nil {
|
||||
// Update the file with new CA content
|
||||
err := os.WriteFile(caFile, tt.updateCA, 0644)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to update CA data with file address: %s", caFile)
|
||||
}
|
||||
}
|
||||
|
||||
clock.Step(holder.caRefreshDuration)
|
||||
|
||||
// Check CA file rotation
|
||||
newTransport := holder.getTransport(t.Context())
|
||||
newRoots := newTransport.TLSClientConfig.RootCAs
|
||||
|
||||
if newRoots == nil || !expectedRoots.Equal(newRoots) {
|
||||
t.Error("new roots did not match expected roots")
|
||||
}
|
||||
|
||||
transportRotated := newTransport != transport
|
||||
if tt.expectRotation != transportRotated {
|
||||
t.Error("transport rotation did not match")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateServerCertAndCA(t testing.TB) (servingCertPEM, servingKeyPEM, caCertPEM []byte) {
|
||||
t.Helper()
|
||||
certPEM, keyPEM, err := cert.GenerateSelfSignedCertKey("127.0.0.1", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate server cert: %v", err)
|
||||
}
|
||||
certs, err := cert.ParseCertsPEM(certPEM)
|
||||
if err != nil || len(certs) < 2 {
|
||||
t.Fatal("Expected cert chain with [leaf, CA]")
|
||||
}
|
||||
caPEM, err := cert.EncodeCertificates(certs[len(certs)-1])
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encode CA cert: %v", err)
|
||||
}
|
||||
return certPEM, keyPEM, caPEM
|
||||
}
|
||||
|
||||
// TestCARotationConnectionBehavior tests end-to-end CA rotation:
|
||||
// 1. Client trusts server CA via a CA file on disk
|
||||
// 2. Server rotates to a new CA + serving cert
|
||||
// 3. Client fails (doesn't trust new CA yet)
|
||||
// 4. CA file updated, transport reloads, client reconnects
|
||||
func TestCARotationConnectionBehavior(t *testing.T) {
|
||||
t.Log("Testing CA Rotation Connection Behavior")
|
||||
|
||||
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true)
|
||||
|
||||
// Generate initial server v1 cert and CA
|
||||
servingCertPEM1, servingKeyPEM1, caCertPEM1 := generateServerCertAndCA(t)
|
||||
|
||||
// Start server v1 (no client cert required - test focuses on server CA rotation)
|
||||
srv1 := newTestServer(t, servingCertPEM1, servingKeyPEM1)
|
||||
srv1.StartTLS()
|
||||
defer srv1.Close()
|
||||
|
||||
// Set up the client
|
||||
clientCAFile := writeCAFile(t, caCertPEM1)
|
||||
config := &Config{
|
||||
TLS: TLSConfig{
|
||||
CAFile: clientCAFile,
|
||||
},
|
||||
}
|
||||
|
||||
transport, err := New(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create transport: %v", err)
|
||||
}
|
||||
transport.(*atomicTransportHolder).caRefreshDuration = 500 * time.Millisecond
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Initial connection must succeed
|
||||
t.Log("Making initial request to server v1, expecting success...")
|
||||
resp, err := client.Get(srv1.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call the server v1: %v", err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
t.Fatal("Failed to close the response.")
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatal("Failed to call the server successfully.")
|
||||
}
|
||||
|
||||
t.Log("Initial connection successful.")
|
||||
|
||||
// Rotate: new CA, new server cert, same address
|
||||
t.Log("Stopping server v1 and starting server v2 with new CA...")
|
||||
srv1Addr := srv1.Listener.Addr().String()
|
||||
srv1.Close()
|
||||
|
||||
servingCertPEM2, servingKeyPEM2, caCertPEM2 := generateServerCertAndCA(t)
|
||||
srv2 := newTestServer(t, servingCertPEM2, servingKeyPEM2)
|
||||
l, err := net.Listen("tcp", srv1Addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to re-claim the same server address: %v", err)
|
||||
}
|
||||
|
||||
srv2.Listener = l
|
||||
srv2.StartTLS()
|
||||
defer srv2.Close()
|
||||
|
||||
// Must fail - client still trusts old CA
|
||||
t.Log("Making request to server v2, expecting failure...")
|
||||
_, err = client.Get(srv2.URL)
|
||||
if err == nil {
|
||||
t.Fatal("The request should fail.")
|
||||
}
|
||||
t.Log("Request failed as expected.")
|
||||
|
||||
// Update CA file to trust new CA
|
||||
t.Log("Updating client CA file on disk to trust new CA...")
|
||||
if err := os.WriteFile(clientCAFile, caCertPEM2, 0644); err != nil {
|
||||
t.Fatalf("Failed to update CA file: %v", err)
|
||||
}
|
||||
|
||||
// Poll until transport reloads
|
||||
t.Log("Polling server v2 until the client's transport reloads the new CA...")
|
||||
var lastPollErr error
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer ctxCancel()
|
||||
|
||||
pollErr := wait.PollUntilContextCancel(ctx, 500*time.Millisecond, true, func(ctx context.Context) (bool, error) {
|
||||
resp, err := client.Get(srv2.URL)
|
||||
if err != nil {
|
||||
lastPollErr = err
|
||||
t.Log("Client failed to connect before the root CAs are updated, will retry...")
|
||||
return false, nil // Error is expected, continue polling
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
t.Fatal("Failed to close the response.")
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return true, nil // Success! Stop polling.
|
||||
}
|
||||
return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
})
|
||||
if pollErr != nil {
|
||||
t.Fatalf("Client failed to reconnect after CA rotation. Last error: %v. Test error: %v", lastPollErr, pollErr)
|
||||
}
|
||||
|
||||
t.Log("Success! Client reconnected after CA was refreshed.")
|
||||
}
|
||||
|
||||
func TestCARotationConnectionBehavior_Disabled(t *testing.T) {
|
||||
t.Log("Testing CA Rotation Connection Behavior (Feature Disabled)")
|
||||
|
||||
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false)
|
||||
|
||||
// Generate initial server v1 cert and CA
|
||||
servingCertPEM1, servingKeyPEM1, caCertPEM1 := generateServerCertAndCA(t)
|
||||
|
||||
// Start server v1 (no client cert required - test focuses on server CA rotation)
|
||||
srv1 := newTestServer(t, servingCertPEM1, servingKeyPEM1)
|
||||
srv1.StartTLS()
|
||||
defer srv1.Close()
|
||||
|
||||
// Set up the client
|
||||
clientCAFile := writeCAFile(t, caCertPEM1)
|
||||
config := &Config{
|
||||
TLS: TLSConfig{
|
||||
CAFile: clientCAFile,
|
||||
},
|
||||
}
|
||||
|
||||
transport, err := New(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create transport: %v", err)
|
||||
}
|
||||
if _, ok := transport.(*atomicTransportHolder); ok {
|
||||
t.Fatal("Expected plain transport when the feature gate is disabled, got atomicTransportHolder")
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Initial connection must succeed
|
||||
t.Log("Making initial request to server v1, expecting success...")
|
||||
resp, err := client.Get(srv1.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call the server v1: %v", err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
t.Fatal("Failed to close the response.")
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatal("Failed to call the server successfully.")
|
||||
}
|
||||
|
||||
t.Log("Initial connection successful.")
|
||||
|
||||
// Rotate: new CA, new server cert, same address
|
||||
t.Log("Stopping server v1 and starting server v2 with new CA...")
|
||||
srv1Addr := srv1.Listener.Addr().String()
|
||||
srv1.Close()
|
||||
|
||||
servingCertPEM2, servingKeyPEM2, caCertPEM2 := generateServerCertAndCA(t)
|
||||
srv2 := newTestServer(t, servingCertPEM2, servingKeyPEM2)
|
||||
l, err := net.Listen("tcp", srv1Addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to re-claim the same server address: %v", err)
|
||||
}
|
||||
|
||||
srv2.Listener = l
|
||||
srv2.StartTLS()
|
||||
defer srv2.Close()
|
||||
|
||||
// Must fail - client still trusts old CA
|
||||
t.Log("Making request to server v2, expecting failure...")
|
||||
_, err = client.Get(srv2.URL)
|
||||
if err == nil {
|
||||
t.Fatal("The request should fail.")
|
||||
}
|
||||
t.Log("Request failed as expected.")
|
||||
|
||||
// Update CA file to trust new CA
|
||||
t.Log("Updating client CA file on disk to trust new CA...")
|
||||
if err := os.WriteFile(clientCAFile, caCertPEM2, 0644); err != nil {
|
||||
t.Fatalf("Failed to update CA file: %v", err)
|
||||
}
|
||||
|
||||
// Poll to ensure transport DOES NOT reload
|
||||
t.Log("Polling server v2 to verify the client DOES NOT reconnect...")
|
||||
var lastPollErr error
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer ctxCancel()
|
||||
|
||||
pollErr := wait.PollUntilContextCancel(ctx, 500*time.Millisecond, true, func(ctx context.Context) (bool, error) {
|
||||
resp, err := client.Get(srv2.URL)
|
||||
if err != nil {
|
||||
lastPollErr = err
|
||||
t.Log("Client failed to connect (expected because feature is disabled)...")
|
||||
return false, nil // Keep polling until the timeout is reached
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
t.Fatal("Failed to close the response.")
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return true, nil // Success! But this means the test failed.
|
||||
}
|
||||
return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
})
|
||||
|
||||
// If pollErr is nil, it means the connection succeeded, which is wrong for this test.
|
||||
if pollErr == nil {
|
||||
t.Fatalf("Client unexpectedly reconnected after CA rotation! The feature gate is disabled, so this should have failed.")
|
||||
}
|
||||
|
||||
t.Logf("Success! Client permanently failed to reconnect as expected. Last error: %v", lastPollErr)
|
||||
}
|
||||
|
||||
// testCAReloadsMetric is a fake metric recorder that records calls to Increment.
|
||||
type testCAReloadsMetric struct {
|
||||
calls []caReloadCall
|
||||
}
|
||||
|
||||
type caReloadCall struct {
|
||||
result, reason string
|
||||
}
|
||||
|
||||
func (m *testCAReloadsMetric) Increment(result, reason string) {
|
||||
m.calls = append(m.calls, caReloadCall{result, reason})
|
||||
}
|
||||
|
||||
// TestCARotationMetricsEmitted verifies that ca_rotation.go emits the correct
|
||||
// metrics during actual CA reload operations.
|
||||
func TestCARotationMetricsEmitted(t *testing.T) {
|
||||
fakeMetricRecorder := &testCAReloadsMetric{}
|
||||
origMetric := metrics.TransportCAReloads
|
||||
metrics.TransportCAReloads = fakeMetricRecorder
|
||||
t.Cleanup(func() { metrics.TransportCAReloads = origMetric })
|
||||
|
||||
caData := []byte(testCACert1)
|
||||
caFile := writeCAFile(t, caData)
|
||||
transport := createTestTransport(t, caData)
|
||||
|
||||
clock := testingclock.NewFakeClock(time.Now())
|
||||
holder := newAtomicTransportHolder(caFile, caData, transport)
|
||||
holder.clock = clock
|
||||
holder.transportLastChecked = clock.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() error
|
||||
wantResult string
|
||||
wantReason string
|
||||
}{
|
||||
{
|
||||
name: "unchanged CA",
|
||||
setup: func() error { return nil },
|
||||
wantResult: "success",
|
||||
wantReason: "unchanged",
|
||||
},
|
||||
{
|
||||
name: "updated CA",
|
||||
setup: func() error {
|
||||
return os.WriteFile(caFile, []byte(testCACert2), 0644)
|
||||
},
|
||||
wantResult: "success",
|
||||
wantReason: "updated",
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
setup: func() error {
|
||||
return os.WriteFile(caFile, []byte{}, 0644)
|
||||
},
|
||||
wantResult: "failure",
|
||||
wantReason: "empty",
|
||||
},
|
||||
{
|
||||
name: "invalid CA data",
|
||||
setup: func() error {
|
||||
return os.WriteFile(caFile, []byte("not-a-cert"), 0644)
|
||||
},
|
||||
wantResult: "failure",
|
||||
wantReason: "ca_parse_error",
|
||||
},
|
||||
{
|
||||
name: "read error",
|
||||
setup: func() error {
|
||||
return os.Remove(caFile)
|
||||
},
|
||||
wantResult: "failure",
|
||||
wantReason: "read_error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fakeMetricRecorder.calls = nil
|
||||
err := tt.setup()
|
||||
if err != nil {
|
||||
t.Fatalf("Setup failed: %v", err)
|
||||
}
|
||||
clock.Step(holder.caRefreshDuration)
|
||||
holder.getTransport(context.Background())
|
||||
|
||||
if len(fakeMetricRecorder.calls) != 1 {
|
||||
t.Fatalf("Expected 1 metric call, got %d", len(fakeMetricRecorder.calls))
|
||||
}
|
||||
if fakeMetricRecorder.calls[0].result != tt.wantResult || fakeMetricRecorder.calls[0].reason != tt.wantReason {
|
||||
t.Errorf("Got metric(%s, %s), want (%s, %s)",
|
||||
fakeMetricRecorder.calls[0].result, fakeMetricRecorder.calls[0].reason, tt.wantResult, tt.wantReason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// helper to create a simple, non-blocking test server with a given certificate.
|
||||
func newTestServer(t *testing.T, certPEM, keyPEM []byte) *httptest.Server {
|
||||
t.Helper()
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := fmt.Fprint(w, "ok"); err != nil {
|
||||
t.Fatal("Failed to write to the response.")
|
||||
}
|
||||
}))
|
||||
// Configure server TLS
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server cert: %v", err)
|
||||
}
|
||||
|
||||
server.TLS = &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
}
|
||||
return server
|
||||
}
|
||||
@@ -36,7 +36,7 @@ import (
|
||||
// the config has no custom TLS options, http.DefaultTransport is returned.
|
||||
type tlsTransportCache struct {
|
||||
mu sync.Mutex
|
||||
transports map[tlsCacheKey]*http.Transport
|
||||
transports map[tlsCacheKey]http.RoundTripper
|
||||
}
|
||||
|
||||
// DialerStopCh is stop channel that is passed down to dynamic cert dialer.
|
||||
@@ -46,11 +46,14 @@ var DialerStopCh = wait.NeverStop
|
||||
|
||||
const idleConnsPerHost = 25
|
||||
|
||||
var tlsCache = &tlsTransportCache{transports: make(map[tlsCacheKey]*http.Transport)}
|
||||
var tlsCache = &tlsTransportCache{
|
||||
transports: make(map[tlsCacheKey]http.RoundTripper),
|
||||
}
|
||||
|
||||
type tlsCacheKey struct {
|
||||
insecure bool
|
||||
caData string
|
||||
caFile string
|
||||
certData string
|
||||
keyData string `datapolicy:"security-key"`
|
||||
certFile string
|
||||
@@ -68,8 +71,8 @@ func (t tlsCacheKey) String() string {
|
||||
if len(t.keyData) > 0 {
|
||||
keyText = "<redacted>"
|
||||
}
|
||||
return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p",
|
||||
t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial)
|
||||
return fmt.Sprintf("insecure:%v, caData:%#v, caFile:%s, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p",
|
||||
t.insecure, t.caData, t.caFile, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial)
|
||||
}
|
||||
|
||||
func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
|
||||
@@ -131,7 +134,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
|
||||
proxy = config.Proxy
|
||||
}
|
||||
|
||||
transport := utilnet.SetTransportDefaults(&http.Transport{
|
||||
httpTransport := utilnet.SetTransportDefaults(&http.Transport{
|
||||
Proxy: proxy,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
@@ -139,6 +142,11 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
|
||||
DialContext: dial,
|
||||
DisableCompression: config.DisableCompression,
|
||||
})
|
||||
var transport http.RoundTripper = httpTransport
|
||||
|
||||
if config.TLS.ReloadCAFiles && tlsConfig != nil && tlsConfig.RootCAs != nil && len(config.TLS.CAFile) > 0 {
|
||||
transport = newAtomicTransportHolder(config.TLS.CAFile, config.TLS.CAData, httpTransport)
|
||||
}
|
||||
|
||||
if canCache {
|
||||
// Cache a single transport for these options
|
||||
@@ -162,7 +170,6 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
|
||||
|
||||
k := tlsCacheKey{
|
||||
insecure: c.TLS.Insecure,
|
||||
caData: string(c.TLS.CAData),
|
||||
serverName: c.TLS.ServerName,
|
||||
nextProtos: strings.Join(c.TLS.NextProtos, ","),
|
||||
disableCompression: c.DisableCompression,
|
||||
@@ -178,5 +185,14 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) {
|
||||
k.keyData = string(c.TLS.KeyData)
|
||||
}
|
||||
|
||||
if c.TLS.ReloadCAFiles {
|
||||
// When reloading CA files, include CA file path in cache key instead of CA data
|
||||
// This allows the CA to be reloaded from disk on each transport creation
|
||||
k.caFile = c.TLS.CAFile
|
||||
} else {
|
||||
// When not reloading, cache the CA data directly
|
||||
k.caData = string(c.TLS.CAData)
|
||||
}
|
||||
|
||||
return k, true, nil
|
||||
}
|
||||
|
||||
@@ -19,14 +19,22 @@ package transport
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
clientgofeaturegate "k8s.io/client-go/features"
|
||||
clientfeaturestesting "k8s.io/client-go/features/testing"
|
||||
)
|
||||
|
||||
func TestTLSConfigKey(t *testing.T) {
|
||||
|
||||
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true)
|
||||
// Make sure config fields that don't affect the tls config don't affect the cache key
|
||||
identicalConfigurations := map[string]*Config{
|
||||
"empty": {},
|
||||
@@ -70,14 +78,17 @@ func TestTLSConfigKey(t *testing.T) {
|
||||
// Make sure config fields that affect the tls config affect the cache key
|
||||
dialer := net.Dialer{}
|
||||
getCert := &GetCertHolder{GetCert: func() (*tls.Certificate, error) { return nil, nil }}
|
||||
caFile := writeCAFile(t, []byte(testCACert1))
|
||||
uniqueConfigurations := map[string]*Config{
|
||||
"proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }},
|
||||
"no tls": {},
|
||||
"dialer": {DialHolder: &DialHolder{Dial: dialer.DialContext}},
|
||||
"dialer2": {DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
|
||||
"insecure": {TLS: TLSConfig{Insecure: true}},
|
||||
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
|
||||
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
|
||||
"proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }},
|
||||
"no tls": {},
|
||||
"dialer": {DialHolder: &DialHolder{Dial: dialer.DialContext}},
|
||||
"dialer2": {DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}},
|
||||
"insecure": {TLS: TLSConfig{Insecure: true}},
|
||||
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
|
||||
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},
|
||||
"with only ca file": {TLS: TLSConfig{CAFile: caFile}},
|
||||
"with both ca file and ca data": {TLS: TLSConfig{CAFile: caFile, CAData: []byte(testCACert1)}},
|
||||
"cert 1, key 1": {
|
||||
TLS: TLSConfig{
|
||||
CertData: []byte{1},
|
||||
@@ -188,3 +199,228 @@ func TestTLSConfigKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSConfigKeyCARotationDisabled(t *testing.T) {
|
||||
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false)
|
||||
|
||||
caFile := writeCAFile(t, []byte(testCACert1))
|
||||
|
||||
// When feature is disabled, CAFile-only config resolves CAData via
|
||||
// loadTLSFiles, so two configs with the same file content get the same key.
|
||||
config1 := &Config{TLS: TLSConfig{CAFile: caFile}}
|
||||
config2 := &Config{TLS: TLSConfig{CAData: []byte(testCACert1)}}
|
||||
|
||||
if err := loadTLSFiles(config1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := loadTLSFiles(config2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key1, canCache1, err := tlsConfigKey(config1)
|
||||
if err != nil || !canCache1 {
|
||||
t.Fatalf("unexpected: err=%v, canCache=%v", err, canCache1)
|
||||
}
|
||||
|
||||
key2, canCache2, err := tlsConfigKey(config2)
|
||||
if err != nil || !canCache2 {
|
||||
t.Fatalf("unexpected: err=%v, canCache=%v", err, canCache2)
|
||||
}
|
||||
|
||||
if key1 != key2 {
|
||||
t.Error("Expected same cache key when feature is disabled (CAFile resolved to CAData)")
|
||||
}
|
||||
if config1.TLS.ReloadCAFiles {
|
||||
t.Error("Expected ReloadCAFiles=false when feature gate is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTLSTransportCacheCARotation tests transport cache behavior with CA rotation
|
||||
func TestTLSTransportCacheCARotation(t *testing.T) {
|
||||
|
||||
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true)
|
||||
caFile := writeCAFile(t, []byte(testCACert1))
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *Config
|
||||
expectWrapper bool
|
||||
expectCacheable bool
|
||||
}{
|
||||
{
|
||||
name: "CA rotation should be enabled when only the CAFile is set",
|
||||
config: &Config{
|
||||
TLS: TLSConfig{
|
||||
CAFile: caFile,
|
||||
},
|
||||
},
|
||||
expectWrapper: true,
|
||||
expectCacheable: true,
|
||||
},
|
||||
{
|
||||
name: "CA rotation should be disabled when both CAFile and CAData are set",
|
||||
config: &Config{
|
||||
TLS: TLSConfig{
|
||||
CAFile: caFile,
|
||||
CAData: []byte(testCACert1),
|
||||
},
|
||||
},
|
||||
expectWrapper: false,
|
||||
expectCacheable: true,
|
||||
},
|
||||
{
|
||||
name: "CA rotation should be disabled when only the CAData is set",
|
||||
config: &Config{
|
||||
TLS: TLSConfig{
|
||||
CAData: []byte(testCACert1),
|
||||
},
|
||||
},
|
||||
expectWrapper: false,
|
||||
expectCacheable: true,
|
||||
},
|
||||
{
|
||||
name: "no TLS config",
|
||||
config: &Config{
|
||||
TLS: TLSConfig{},
|
||||
},
|
||||
expectWrapper: false,
|
||||
expectCacheable: false, // No TLS config means default transport
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create new cache for testing
|
||||
tlsCaches := &tlsTransportCache{
|
||||
transports: make(map[tlsCacheKey]http.RoundTripper),
|
||||
}
|
||||
|
||||
rt, err := tlsCaches.get(tc.config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !tc.expectCacheable {
|
||||
// Should return default transport
|
||||
if rt != http.DefaultTransport {
|
||||
t.Errorf("Expected default transport, got %T", rt)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expectWrapper {
|
||||
// Should be wrapped in atomicTransportHolder
|
||||
if _, ok := rt.(*atomicTransportHolder); !ok {
|
||||
t.Errorf("Expected atomicTransportHolder, got %T", rt)
|
||||
}
|
||||
if !tc.config.TLS.ReloadCAFiles {
|
||||
t.Errorf("Expected ReloadCAFiles to be true, got %v", tc.config.TLS.ReloadCAFiles)
|
||||
}
|
||||
} else {
|
||||
// Should be a regular http.Transport
|
||||
if _, ok := rt.(*http.Transport); !ok {
|
||||
t.Errorf("Expected *http.Transport, got %T", rt)
|
||||
}
|
||||
}
|
||||
|
||||
// Test caching: second call should return the same instance
|
||||
rt2, err := tlsCaches.get(tc.config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error on second call: %v", err)
|
||||
}
|
||||
|
||||
if rt != rt2 {
|
||||
t.Error("Expected same transport instance from cache")
|
||||
}
|
||||
|
||||
// Verify cache size
|
||||
tlsCaches.mu.Lock()
|
||||
cacheSize := len(tlsCaches.transports)
|
||||
tlsCaches.mu.Unlock()
|
||||
|
||||
expectedCacheSize := 1
|
||||
if !tc.expectCacheable {
|
||||
expectedCacheSize = 0
|
||||
}
|
||||
|
||||
if cacheSize != expectedCacheSize {
|
||||
t.Errorf("Expected %d transports in cache, got %d", expectedCacheSize, cacheSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSTransportCacheCARotationDisabled(t *testing.T) {
|
||||
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, false)
|
||||
|
||||
caFile := writeCAFile(t, []byte(testCACert1))
|
||||
cache := &tlsTransportCache{transports: make(map[tlsCacheKey]http.RoundTripper)}
|
||||
|
||||
rt, err := cache.get(&Config{TLS: TLSConfig{CAFile: caFile}})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := rt.(*atomicTransportHolder); ok {
|
||||
t.Error("Expected plain *http.Transport when feature gate is disabled, got atomicTransportHolder")
|
||||
}
|
||||
if _, ok := rt.(*http.Transport); !ok {
|
||||
t.Errorf("Expected *http.Transport, got %T", rt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyCAFileRotationLifecycle(t *testing.T) {
|
||||
// Enable the feature gate for the duration of the test
|
||||
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, true)
|
||||
|
||||
// Create a valid file path, but with empty cert data
|
||||
emptyFile := writeCAFile(t, []byte{})
|
||||
|
||||
config := &Config{
|
||||
TLS: TLSConfig{
|
||||
CAFile: emptyFile,
|
||||
},
|
||||
}
|
||||
|
||||
tlsCaches := &tlsTransportCache{
|
||||
transports: make(map[tlsCacheKey]http.RoundTripper),
|
||||
}
|
||||
|
||||
rt, err := tlsCaches.get(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error getting transport: %v", err)
|
||||
}
|
||||
|
||||
// Verify newAtomicTransportHolder is successfully generated
|
||||
holder, ok := rt.(*atomicTransportHolder)
|
||||
if !ok {
|
||||
t.Fatalf("Expected atomicTransportHolder, got %T", rt)
|
||||
}
|
||||
|
||||
// Verify the initial state: RootCAs should be non-nil but empty
|
||||
initialTransport := holder.getTransport(context.Background())
|
||||
|
||||
if initialTransport.TLSClientConfig == nil || initialTransport.TLSClientConfig.RootCAs == nil {
|
||||
t.Fatal("Expected RootCAs to be non-nil for an empty CA file (should be an empty CertPool)")
|
||||
}
|
||||
emptyPool := x509.NewCertPool()
|
||||
if !initialTransport.TLSClientConfig.RootCAs.Equal(emptyPool) {
|
||||
t.Fatal("Expected initially empty RootCAs")
|
||||
}
|
||||
|
||||
// Write valid cert data into the CA file
|
||||
if err := os.WriteFile(emptyFile, []byte(testCACert1), 0644); err != nil {
|
||||
t.Fatalf("Failed to write to CA file: %v", err)
|
||||
}
|
||||
|
||||
holder.mu.Lock()
|
||||
// Set last checked time far in the past to force a refresh on next getTransport call
|
||||
holder.transportLastChecked = time.Now().Add(-time.Hour)
|
||||
holder.mu.Unlock()
|
||||
|
||||
refreshedTransport := holder.getTransport(context.Background())
|
||||
|
||||
// Verify the refresh succeeded and the cert pool is now populated
|
||||
if refreshedTransport.TLSClientConfig.RootCAs.Equal(emptyPool) {
|
||||
t.Fatal("Expected RootCAs to be populated after writing valid cert data and refreshing")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,7 +134,8 @@ type TLSConfig struct {
|
||||
CAFile string // Path of the PEM-encoded server trusted root certificates.
|
||||
CertFile string // Path of the PEM-encoded client certificate.
|
||||
KeyFile string // Path of the PEM-encoded client key.
|
||||
ReloadTLSFiles bool // Set to indicate that the original config provided files, and that they should be reloaded
|
||||
ReloadTLSFiles bool // Set to indicate that the original config provided files, and that they should be reloaded.
|
||||
ReloadCAFiles bool // Set to indicate that CA files should be reloaded from disk.
|
||||
|
||||
Insecure bool // Server should be accessed without verifying the certificate. For testing only.
|
||||
ServerName string // Override for the server name passed to the server for SNI and used to verify certificates.
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"time"
|
||||
|
||||
utilnet "k8s.io/apimachinery/pkg/util/net"
|
||||
clientgofeaturegate "k8s.io/client-go/features"
|
||||
"k8s.io/klog/v2"
|
||||
)
|
||||
|
||||
@@ -211,17 +212,26 @@ func TLSConfigFor(c *Config) (*tls.Config, error) {
|
||||
// KeyData, and CAFile fields, or returns an error. If no error is returned, all three fields are
|
||||
// either populated or were empty to start.
|
||||
func loadTLSFiles(c *Config) error {
|
||||
// Check that we are purely loading CA from file
|
||||
if clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowCARotation) {
|
||||
if len(c.TLS.CAFile) > 0 && len(c.TLS.CAData) == 0 {
|
||||
c.TLS.ReloadCAFiles = true
|
||||
}
|
||||
} else if c.TLS.ReloadCAFiles {
|
||||
return fmt.Errorf("ReloadCAFiles=true requires ClientsAllowCARotation to be enabled")
|
||||
}
|
||||
|
||||
// Check that we are purely loading certs and keys from files
|
||||
if len(c.TLS.CertFile) > 0 && len(c.TLS.CertData) == 0 && len(c.TLS.KeyFile) > 0 && len(c.TLS.KeyData) == 0 {
|
||||
c.TLS.ReloadTLSFiles = true
|
||||
}
|
||||
|
||||
var err error
|
||||
c.TLS.CAData, err = dataFromSliceOrFile(c.TLS.CAData, c.TLS.CAFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check that we are purely loading from files
|
||||
if len(c.TLS.CertFile) > 0 && len(c.TLS.CertData) == 0 && len(c.TLS.KeyFile) > 0 && len(c.TLS.KeyData) == 0 {
|
||||
c.TLS.ReloadTLSFiles = true
|
||||
}
|
||||
|
||||
c.TLS.CertData, err = dataFromSliceOrFile(c.TLS.CertData, c.TLS.CertFile)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -254,6 +264,11 @@ func rootCertPool(caData []byte) (*x509.CertPool, error) {
|
||||
// code for a look at the platform specific insanity), so we'll use the fact that RootCAs == nil gives us the system values
|
||||
// It doesn't allow trusting either/or, but hopefully that won't be an issue
|
||||
if len(caData) == 0 {
|
||||
// When the ClientsAllowCARotation feature gate is enabled, it returns an empty but non-nil pool.
|
||||
// This ensures we don't fall back to system roots when a user explicitly points CAFile to a zero-byte file.
|
||||
if clientgofeaturegate.FeatureGates().Enabled(clientgofeaturegate.ClientsAllowCARotation) {
|
||||
return x509.NewCertPool(), nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,11 +19,15 @@ package transport
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
clientgofeaturegate "k8s.io/client-go/features"
|
||||
clientfeaturestesting "k8s.io/client-go/features/testing"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -356,6 +360,16 @@ func TestNew(t *testing.T) {
|
||||
}
|
||||
for k, testCase := range testCases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
// The Close method of httptest Server mutates the
|
||||
// `http.DefaultTransport` object, the 'TLSClientConfig'
|
||||
// field mutates from nil to a non nil instance. This introduces flake
|
||||
// and data race when running tests under transport package in parallel.
|
||||
// To work around it we reset the TLSClientConfig field.
|
||||
//
|
||||
// See: https://github.com/golang/go/issues/65796
|
||||
if testCase.Default {
|
||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = nil
|
||||
}
|
||||
rt, err := New(testCase.Config)
|
||||
switch {
|
||||
case testCase.Err && err == nil:
|
||||
@@ -556,3 +570,51 @@ func Test_contextCanceller_RoundTrip(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRootCertPoolEmptyData(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
featureGateEnabled bool
|
||||
expectNilPool bool
|
||||
}{
|
||||
{
|
||||
name: "feature gate disabled returns nil (system roots)",
|
||||
featureGateEnabled: false,
|
||||
expectNilPool: true,
|
||||
},
|
||||
{
|
||||
name: "feature gate enabled returns empty pool (trust nothing)",
|
||||
featureGateEnabled: true,
|
||||
expectNilPool: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Set the feature gate according to the test case
|
||||
clientfeaturestesting.SetFeatureDuringTest(t, clientgofeaturegate.ClientsAllowCARotation, tc.featureGateEnabled)
|
||||
|
||||
// Call the function with empty caData
|
||||
pool, err := rootCertPool([]byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if tc.expectNilPool {
|
||||
if pool != nil {
|
||||
t.Fatalf("Expected pool to be nil when feature gate is disabled, but got a populated pool")
|
||||
}
|
||||
} else {
|
||||
if pool == nil {
|
||||
t.Fatalf("Expected pool to be non-nil (empty pool) when feature gate is enabled, but got nil")
|
||||
}
|
||||
|
||||
// Verify it is truly an empty pool
|
||||
emptyPool := x509.NewCertPool()
|
||||
if !pool.Equal(emptyPool) {
|
||||
t.Fatalf("Expected the returned pool to be completely empty")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user