diff --git a/rest/request.go b/rest/request.go index 4c04eaad..1acd189e 100644 --- a/rest/request.go +++ b/rest/request.go @@ -868,7 +868,7 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp return err } // For connection errors and apiserver shutdown errors retry. - if net.IsConnectionReset(err) { + if net.IsConnectionReset(err) || net.IsProbableEOF(err) { // For the purpose of retry, we set the artificial "retry-after" response. // TODO: Should we clean the original response if it exists? resp = &http.Response{ diff --git a/transport/cache.go b/transport/cache.go index 980d36ae..36d6500f 100644 --- a/transport/cache.go +++ b/transport/cache.go @@ -25,6 +25,7 @@ import ( "time" utilnet "k8s.io/apimachinery/pkg/util/net" + "k8s.io/apimachinery/pkg/util/wait" ) // TlsTransportCache caches TLS http.RoundTrippers different configurations. The @@ -44,6 +45,8 @@ type tlsCacheKey struct { caData string certData string keyData string + certFile string + keyFile string getCert string serverName string nextProtos string @@ -91,6 +94,16 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { KeepAlive: 30 * time.Second, }).DialContext } + + // If we use are reloading files, we need to handle certificate rotation properly + // TODO(jackkleeman): We can also add rotation here when config.HasCertCallback() is true + if config.TLS.ReloadTLSFiles { + dynamicCertDialer := certRotatingDialer(tlsConfig.GetClientCertificate, dial) + tlsConfig.GetClientCertificate = dynamicCertDialer.GetClientCertificate + dial = dynamicCertDialer.connDialer.DialContext + go dynamicCertDialer.Run(wait.NeverStop) + } + // Cache a single transport for these options c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -109,15 +122,23 @@ func tlsConfigKey(c *Config) (tlsCacheKey, error) { if err := loadTLSFiles(c); err != nil { return tlsCacheKey{}, err } - return tlsCacheKey{ + k := tlsCacheKey{ insecure: c.TLS.Insecure, caData: string(c.TLS.CAData), - certData: string(c.TLS.CertData), - keyData: string(c.TLS.KeyData), getCert: fmt.Sprintf("%p", c.TLS.GetCert), serverName: c.TLS.ServerName, nextProtos: strings.Join(c.TLS.NextProtos, ","), dial: fmt.Sprintf("%p", c.Dial), disableCompression: c.DisableCompression, - }, nil + } + + if c.TLS.ReloadTLSFiles { + k.certFile = c.TLS.CertFile + k.keyFile = c.TLS.KeyFile + } else { + k.certData = string(c.TLS.CertData) + k.keyData = string(c.TLS.KeyData) + } + + return k, nil } diff --git a/transport/cert_rotation.go b/transport/cert_rotation.go new file mode 100644 index 00000000..0671b29b --- /dev/null +++ b/transport/cert_rotation.go @@ -0,0 +1,176 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package transport + +import ( + "bytes" + "crypto/tls" + "fmt" + "reflect" + "sync" + "time" + + utilnet "k8s.io/apimachinery/pkg/util/net" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/util/connrotation" + "k8s.io/client-go/util/workqueue" + "k8s.io/klog" +) + +const workItemKey = "key" + +// CertCallbackRefreshDuration is exposed so that integration tests can crank up the reload speed. +var CertCallbackRefreshDuration = 5 * time.Minute + +type reloadFunc func(*tls.CertificateRequestInfo) (*tls.Certificate, error) + +type dynamicClientCert struct { + clientCert *tls.Certificate + certMtx sync.RWMutex + + reload reloadFunc + connDialer *connrotation.Dialer + + // queue only ever has one item, but it has nice error handling backoff/retry semantics + queue workqueue.RateLimitingInterface +} + +func certRotatingDialer(reload reloadFunc, dial utilnet.DialFunc) *dynamicClientCert { + d := &dynamicClientCert{ + reload: reload, + connDialer: connrotation.NewDialer(connrotation.DialFunc(dial)), + queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "DynamicClientCertificate"), + } + + return d +} + +// loadClientCert calls the callback and rotates connections if needed +func (c *dynamicClientCert) loadClientCert() (*tls.Certificate, error) { + cert, err := c.reload(nil) + if err != nil { + return nil, err + } + + // check to see if we have a change. If the values are the same, do nothing. + c.certMtx.RLock() + haveCert := c.clientCert != nil + if certsEqual(c.clientCert, cert) { + c.certMtx.RUnlock() + return c.clientCert, nil + } + c.certMtx.RUnlock() + + c.certMtx.Lock() + c.clientCert = cert + c.certMtx.Unlock() + + // The first certificate requested is not a rotation that is worth closing connections for + if !haveCert { + return cert, nil + } + + klog.V(1).Infof("certificate rotation detected, shutting down client connections to start using new credentials") + c.connDialer.CloseAll() + + return cert, nil +} + +// certsEqual compares tls Certificates, ignoring the Leaf which may get filled in dynamically +func certsEqual(left, right *tls.Certificate) bool { + if left == nil || right == nil { + return left == right + } + + if !byteMatrixEqual(left.Certificate, right.Certificate) { + return false + } + + if !reflect.DeepEqual(left.PrivateKey, right.PrivateKey) { + return false + } + + if !byteMatrixEqual(left.SignedCertificateTimestamps, right.SignedCertificateTimestamps) { + return false + } + + if !bytes.Equal(left.OCSPStaple, right.OCSPStaple) { + return false + } + + return true +} + +func byteMatrixEqual(left, right [][]byte) bool { + if len(left) != len(right) { + return false + } + + for i := range left { + if !bytes.Equal(left[i], right[i]) { + return false + } + } + return true +} + +// run starts the controller and blocks until stopCh is closed. +func (c *dynamicClientCert) Run(stopCh <-chan struct{}) { + defer utilruntime.HandleCrash() + defer c.queue.ShutDown() + + klog.Infof("Starting client certificate rotation controller") + defer klog.Infof("Shutting down client certificate rotation controller") + + go wait.Until(c.runWorker, time.Second, stopCh) + + go wait.PollImmediateUntil(CertCallbackRefreshDuration, func() (bool, error) { + c.queue.Add(workItemKey) + return false, nil + }, stopCh) + + <-stopCh +} + +func (c *dynamicClientCert) runWorker() { + for c.processNextWorkItem() { + } +} + +func (c *dynamicClientCert) processNextWorkItem() bool { + dsKey, quit := c.queue.Get() + if quit { + return false + } + defer c.queue.Done(dsKey) + + _, err := c.loadClientCert() + if err == nil { + c.queue.Forget(dsKey) + return true + } + + utilruntime.HandleError(fmt.Errorf("%v failed with : %v", dsKey, err)) + c.queue.AddRateLimited(dsKey) + + return true +} + +func (c *dynamicClientCert) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return c.loadClientCert() +} diff --git a/transport/config.go b/transport/config.go index 9e18d11d..c20a4a8f 100644 --- a/transport/config.go +++ b/transport/config.go @@ -115,9 +115,10 @@ func (c *Config) Wrap(fn WrapperFunc) { // TLSConfig holds the information needed to set up a TLS transport. type TLSConfig struct { - CAFile string // Path of the PEM-encoded server trusted root certificates. - CertFile string // Path of the PEM-encoded client certificate. - KeyFile string // Path of the PEM-encoded client key. + CAFile string // Path of the PEM-encoded server trusted root certificates. + CertFile string // Path of the PEM-encoded client certificate. + KeyFile string // Path of the PEM-encoded client key. + ReloadTLSFiles bool // Set to indicate that the original config provided files, and that they should be reloaded Insecure bool // Server should be accessed without verifying the certificate. For testing only. ServerName string // Override for the server name passed to the server for SNI and used to verify certificates. diff --git a/transport/transport.go b/transport/transport.go index cd8de982..143ebfa5 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -23,6 +23,8 @@ import ( "fmt" "io/ioutil" "net/http" + "sync" + "time" utilnet "k8s.io/apimachinery/pkg/util/net" "k8s.io/klog" @@ -81,7 +83,8 @@ func TLSConfigFor(c *Config) (*tls.Config, error) { } var staticCert *tls.Certificate - if c.HasCertAuth() { + // Treat cert as static if either key or cert was data, not a file + if c.HasCertAuth() && !c.TLS.ReloadTLSFiles { // If key/cert were provided, verify them before setting up // tlsConfig.GetClientCertificate. cert, err := tls.X509KeyPair(c.TLS.CertData, c.TLS.KeyData) @@ -91,6 +94,11 @@ func TLSConfigFor(c *Config) (*tls.Config, error) { staticCert = &cert } + var dynamicCertLoader func() (*tls.Certificate, error) + if c.TLS.ReloadTLSFiles { + dynamicCertLoader = cachingCertificateLoader(c.TLS.CertFile, c.TLS.KeyFile) + } + if c.HasCertAuth() || c.HasCertCallback() { tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { // Note: static key/cert data always take precedence over cert @@ -98,6 +106,10 @@ func TLSConfigFor(c *Config) (*tls.Config, error) { if staticCert != nil { return staticCert, nil } + // key/cert files lead to ReloadTLSFiles being set - takes precedence over cert callback + if dynamicCertLoader != nil { + return dynamicCertLoader() + } if c.HasCertCallback() { cert, err := c.TLS.GetCert() if err != nil { @@ -129,6 +141,11 @@ func loadTLSFiles(c *Config) error { return err } + // Check that we are purely loading from files + if len(c.TLS.CertFile) > 0 && len(c.TLS.CertData) == 0 && len(c.TLS.KeyFile) > 0 && len(c.TLS.KeyData) == 0 { + c.TLS.ReloadTLSFiles = true + } + c.TLS.CertData, err = dataFromSliceOrFile(c.TLS.CertData, c.TLS.CertFile) if err != nil { return err @@ -243,3 +260,44 @@ func tryCancelRequest(rt http.RoundTripper, req *http.Request) { klog.Warningf("Unable to cancel request for %T", rt) } } + +type certificateCacheEntry struct { + cert *tls.Certificate + err error + birth time.Time +} + +// isStale returns true when this cache entry is too old to be usable +func (c *certificateCacheEntry) isStale() bool { + return time.Now().Sub(c.birth) > time.Second +} + +func newCertificateCacheEntry(certFile, keyFile string) certificateCacheEntry { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + return certificateCacheEntry{cert: &cert, err: err, birth: time.Now()} +} + +// cachingCertificateLoader ensures that we don't hammer the filesystem when opening many connections +// the underlying cert files are read at most once every second +func cachingCertificateLoader(certFile, keyFile string) func() (*tls.Certificate, error) { + current := newCertificateCacheEntry(certFile, keyFile) + var currentMtx sync.RWMutex + + return func() (*tls.Certificate, error) { + currentMtx.RLock() + if current.isStale() { + currentMtx.RUnlock() + + currentMtx.Lock() + defer currentMtx.Unlock() + + if current.isStale() { + current = newCertificateCacheEntry(certFile, keyFile) + } + } else { + defer currentMtx.RUnlock() + } + + return current.cert, current.err + } +}