diff --git a/staging/src/k8s.io/client-go/rest/request.go b/staging/src/k8s.io/client-go/rest/request.go index 956d924d340..be1b59437aa 100644 --- a/staging/src/k8s.io/client-go/rest/request.go +++ b/staging/src/k8s.io/client-go/rest/request.go @@ -835,7 +835,7 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp return err } // For connection errors and apiserver shutdown errors retry. - if net.IsConnectionReset(err) { + if net.IsConnectionReset(err) || net.IsProbableEOF(err) { // For the purpose of retry, we set the artificial "retry-after" response. // TODO: Should we clean the original response if it exists? resp = &http.Response{ diff --git a/staging/src/k8s.io/client-go/transport/BUILD b/staging/src/k8s.io/client-go/transport/BUILD index 643750e2563..3ebf8930d2d 100644 --- a/staging/src/k8s.io/client-go/transport/BUILD +++ b/staging/src/k8s.io/client-go/transport/BUILD @@ -22,6 +22,7 @@ go_library( name = "go_default_library", srcs = [ "cache.go", + "cert_rotation.go", "config.go", "round_trippers.go", "token_source.go", @@ -31,6 +32,10 @@ go_library( importpath = "k8s.io/client-go/transport", deps = [ "//staging/src/k8s.io/apimachinery/pkg/util/net:go_default_library", + "//staging/src/k8s.io/apimachinery/pkg/util/runtime:go_default_library", + "//staging/src/k8s.io/apimachinery/pkg/util/wait:go_default_library", + "//staging/src/k8s.io/client-go/util/connrotation:go_default_library", + "//staging/src/k8s.io/client-go/util/workqueue:go_default_library", "//vendor/golang.org/x/oauth2:go_default_library", "//vendor/k8s.io/klog:go_default_library", ], diff --git a/staging/src/k8s.io/client-go/transport/cache.go b/staging/src/k8s.io/client-go/transport/cache.go index 980d36ae128..36d6500f581 100644 --- a/staging/src/k8s.io/client-go/transport/cache.go +++ b/staging/src/k8s.io/client-go/transport/cache.go @@ -25,6 +25,7 @@ import ( "time" utilnet "k8s.io/apimachinery/pkg/util/net" + "k8s.io/apimachinery/pkg/util/wait" ) // TlsTransportCache caches TLS http.RoundTrippers different configurations. The @@ -44,6 +45,8 @@ type tlsCacheKey struct { caData string certData string keyData string + certFile string + keyFile string getCert string serverName string nextProtos string @@ -91,6 +94,16 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { KeepAlive: 30 * time.Second, }).DialContext } + + // If we use are reloading files, we need to handle certificate rotation properly + // TODO(jackkleeman): We can also add rotation here when config.HasCertCallback() is true + if config.TLS.ReloadTLSFiles { + dynamicCertDialer := certRotatingDialer(tlsConfig.GetClientCertificate, dial) + tlsConfig.GetClientCertificate = dynamicCertDialer.GetClientCertificate + dial = dynamicCertDialer.connDialer.DialContext + go dynamicCertDialer.Run(wait.NeverStop) + } + // Cache a single transport for these options c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -109,15 +122,23 @@ func tlsConfigKey(c *Config) (tlsCacheKey, error) { if err := loadTLSFiles(c); err != nil { return tlsCacheKey{}, err } - return tlsCacheKey{ + k := tlsCacheKey{ insecure: c.TLS.Insecure, caData: string(c.TLS.CAData), - certData: string(c.TLS.CertData), - keyData: string(c.TLS.KeyData), getCert: fmt.Sprintf("%p", c.TLS.GetCert), serverName: c.TLS.ServerName, nextProtos: strings.Join(c.TLS.NextProtos, ","), dial: fmt.Sprintf("%p", c.Dial), disableCompression: c.DisableCompression, - }, nil + } + + if c.TLS.ReloadTLSFiles { + k.certFile = c.TLS.CertFile + k.keyFile = c.TLS.KeyFile + } else { + k.certData = string(c.TLS.CertData) + k.keyData = string(c.TLS.KeyData) + } + + return k, nil } diff --git a/staging/src/k8s.io/client-go/transport/cert_rotation.go b/staging/src/k8s.io/client-go/transport/cert_rotation.go new file mode 100644 index 00000000000..0671b29bb2c --- /dev/null +++ b/staging/src/k8s.io/client-go/transport/cert_rotation.go @@ -0,0 +1,176 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package transport + +import ( + "bytes" + "crypto/tls" + "fmt" + "reflect" + "sync" + "time" + + utilnet "k8s.io/apimachinery/pkg/util/net" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/util/connrotation" + "k8s.io/client-go/util/workqueue" + "k8s.io/klog" +) + +const workItemKey = "key" + +// CertCallbackRefreshDuration is exposed so that integration tests can crank up the reload speed. +var CertCallbackRefreshDuration = 5 * time.Minute + +type reloadFunc func(*tls.CertificateRequestInfo) (*tls.Certificate, error) + +type dynamicClientCert struct { + clientCert *tls.Certificate + certMtx sync.RWMutex + + reload reloadFunc + connDialer *connrotation.Dialer + + // queue only ever has one item, but it has nice error handling backoff/retry semantics + queue workqueue.RateLimitingInterface +} + +func certRotatingDialer(reload reloadFunc, dial utilnet.DialFunc) *dynamicClientCert { + d := &dynamicClientCert{ + reload: reload, + connDialer: connrotation.NewDialer(connrotation.DialFunc(dial)), + queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "DynamicClientCertificate"), + } + + return d +} + +// loadClientCert calls the callback and rotates connections if needed +func (c *dynamicClientCert) loadClientCert() (*tls.Certificate, error) { + cert, err := c.reload(nil) + if err != nil { + return nil, err + } + + // check to see if we have a change. If the values are the same, do nothing. + c.certMtx.RLock() + haveCert := c.clientCert != nil + if certsEqual(c.clientCert, cert) { + c.certMtx.RUnlock() + return c.clientCert, nil + } + c.certMtx.RUnlock() + + c.certMtx.Lock() + c.clientCert = cert + c.certMtx.Unlock() + + // The first certificate requested is not a rotation that is worth closing connections for + if !haveCert { + return cert, nil + } + + klog.V(1).Infof("certificate rotation detected, shutting down client connections to start using new credentials") + c.connDialer.CloseAll() + + return cert, nil +} + +// certsEqual compares tls Certificates, ignoring the Leaf which may get filled in dynamically +func certsEqual(left, right *tls.Certificate) bool { + if left == nil || right == nil { + return left == right + } + + if !byteMatrixEqual(left.Certificate, right.Certificate) { + return false + } + + if !reflect.DeepEqual(left.PrivateKey, right.PrivateKey) { + return false + } + + if !byteMatrixEqual(left.SignedCertificateTimestamps, right.SignedCertificateTimestamps) { + return false + } + + if !bytes.Equal(left.OCSPStaple, right.OCSPStaple) { + return false + } + + return true +} + +func byteMatrixEqual(left, right [][]byte) bool { + if len(left) != len(right) { + return false + } + + for i := range left { + if !bytes.Equal(left[i], right[i]) { + return false + } + } + return true +} + +// run starts the controller and blocks until stopCh is closed. +func (c *dynamicClientCert) Run(stopCh <-chan struct{}) { + defer utilruntime.HandleCrash() + defer c.queue.ShutDown() + + klog.Infof("Starting client certificate rotation controller") + defer klog.Infof("Shutting down client certificate rotation controller") + + go wait.Until(c.runWorker, time.Second, stopCh) + + go wait.PollImmediateUntil(CertCallbackRefreshDuration, func() (bool, error) { + c.queue.Add(workItemKey) + return false, nil + }, stopCh) + + <-stopCh +} + +func (c *dynamicClientCert) runWorker() { + for c.processNextWorkItem() { + } +} + +func (c *dynamicClientCert) processNextWorkItem() bool { + dsKey, quit := c.queue.Get() + if quit { + return false + } + defer c.queue.Done(dsKey) + + _, err := c.loadClientCert() + if err == nil { + c.queue.Forget(dsKey) + return true + } + + utilruntime.HandleError(fmt.Errorf("%v failed with : %v", dsKey, err)) + c.queue.AddRateLimited(dsKey) + + return true +} + +func (c *dynamicClientCert) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return c.loadClientCert() +} diff --git a/staging/src/k8s.io/client-go/transport/config.go b/staging/src/k8s.io/client-go/transport/config.go index 9e18d11d38a..c20a4a8fcb9 100644 --- a/staging/src/k8s.io/client-go/transport/config.go +++ b/staging/src/k8s.io/client-go/transport/config.go @@ -115,9 +115,10 @@ func (c *Config) Wrap(fn WrapperFunc) { // TLSConfig holds the information needed to set up a TLS transport. type TLSConfig struct { - CAFile string // Path of the PEM-encoded server trusted root certificates. - CertFile string // Path of the PEM-encoded client certificate. - KeyFile string // Path of the PEM-encoded client key. + CAFile string // Path of the PEM-encoded server trusted root certificates. + CertFile string // Path of the PEM-encoded client certificate. + KeyFile string // Path of the PEM-encoded client key. + ReloadTLSFiles bool // Set to indicate that the original config provided files, and that they should be reloaded Insecure bool // Server should be accessed without verifying the certificate. For testing only. ServerName string // Override for the server name passed to the server for SNI and used to verify certificates. diff --git a/staging/src/k8s.io/client-go/transport/transport.go b/staging/src/k8s.io/client-go/transport/transport.go index cd8de982859..143ebfa5c83 100644 --- a/staging/src/k8s.io/client-go/transport/transport.go +++ b/staging/src/k8s.io/client-go/transport/transport.go @@ -23,6 +23,8 @@ import ( "fmt" "io/ioutil" "net/http" + "sync" + "time" utilnet "k8s.io/apimachinery/pkg/util/net" "k8s.io/klog" @@ -81,7 +83,8 @@ func TLSConfigFor(c *Config) (*tls.Config, error) { } var staticCert *tls.Certificate - if c.HasCertAuth() { + // Treat cert as static if either key or cert was data, not a file + if c.HasCertAuth() && !c.TLS.ReloadTLSFiles { // If key/cert were provided, verify them before setting up // tlsConfig.GetClientCertificate. cert, err := tls.X509KeyPair(c.TLS.CertData, c.TLS.KeyData) @@ -91,6 +94,11 @@ func TLSConfigFor(c *Config) (*tls.Config, error) { staticCert = &cert } + var dynamicCertLoader func() (*tls.Certificate, error) + if c.TLS.ReloadTLSFiles { + dynamicCertLoader = cachingCertificateLoader(c.TLS.CertFile, c.TLS.KeyFile) + } + if c.HasCertAuth() || c.HasCertCallback() { tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { // Note: static key/cert data always take precedence over cert @@ -98,6 +106,10 @@ func TLSConfigFor(c *Config) (*tls.Config, error) { if staticCert != nil { return staticCert, nil } + // key/cert files lead to ReloadTLSFiles being set - takes precedence over cert callback + if dynamicCertLoader != nil { + return dynamicCertLoader() + } if c.HasCertCallback() { cert, err := c.TLS.GetCert() if err != nil { @@ -129,6 +141,11 @@ func loadTLSFiles(c *Config) error { return err } + // Check that we are purely loading from files + if len(c.TLS.CertFile) > 0 && len(c.TLS.CertData) == 0 && len(c.TLS.KeyFile) > 0 && len(c.TLS.KeyData) == 0 { + c.TLS.ReloadTLSFiles = true + } + c.TLS.CertData, err = dataFromSliceOrFile(c.TLS.CertData, c.TLS.CertFile) if err != nil { return err @@ -243,3 +260,44 @@ func tryCancelRequest(rt http.RoundTripper, req *http.Request) { klog.Warningf("Unable to cancel request for %T", rt) } } + +type certificateCacheEntry struct { + cert *tls.Certificate + err error + birth time.Time +} + +// isStale returns true when this cache entry is too old to be usable +func (c *certificateCacheEntry) isStale() bool { + return time.Now().Sub(c.birth) > time.Second +} + +func newCertificateCacheEntry(certFile, keyFile string) certificateCacheEntry { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + return certificateCacheEntry{cert: &cert, err: err, birth: time.Now()} +} + +// cachingCertificateLoader ensures that we don't hammer the filesystem when opening many connections +// the underlying cert files are read at most once every second +func cachingCertificateLoader(certFile, keyFile string) func() (*tls.Certificate, error) { + current := newCertificateCacheEntry(certFile, keyFile) + var currentMtx sync.RWMutex + + return func() (*tls.Certificate, error) { + currentMtx.RLock() + if current.isStale() { + currentMtx.RUnlock() + + currentMtx.Lock() + defer currentMtx.Unlock() + + if current.isStale() { + current = newCertificateCacheEntry(certFile, keyFile) + } + } else { + defer currentMtx.RUnlock() + } + + return current.cert, current.err + } +} diff --git a/test/integration/client/BUILD b/test/integration/client/BUILD index 8f6fd280328..b237151e0ed 100644 --- a/test/integration/client/BUILD +++ b/test/integration/client/BUILD @@ -9,6 +9,7 @@ go_test( name = "go_default_test", size = "large", srcs = [ + "cert_rotation_test.go", "client_test.go", "dynamic_client_test.go", "main_test.go", @@ -32,9 +33,13 @@ go_test( "//staging/src/k8s.io/client-go/dynamic:go_default_library", "//staging/src/k8s.io/client-go/kubernetes:go_default_library", "//staging/src/k8s.io/client-go/kubernetes/scheme:go_default_library", + "//staging/src/k8s.io/client-go/transport:go_default_library", + "//staging/src/k8s.io/client-go/util/cert:go_default_library", "//staging/src/k8s.io/component-base/version:go_default_library", "//test/integration/framework:go_default_library", + "//test/utils:go_default_library", "//test/utils/image:go_default_library", + "//vendor/github.com/stretchr/testify/assert:go_default_library", ], ) diff --git a/test/integration/client/cert_rotation_test.go b/test/integration/client/cert_rotation_test.go new file mode 100644 index 00000000000..ebeae2689d2 --- /dev/null +++ b/test/integration/client/cert_rotation_test.go @@ -0,0 +1,214 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "io/ioutil" + "math" + "math/big" + "os" + "path" + "testing" + "time" + + "github.com/stretchr/testify/assert" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + clientset "k8s.io/client-go/kubernetes" + "k8s.io/client-go/transport" + "k8s.io/client-go/util/cert" + apiservertesting "k8s.io/kubernetes/cmd/kube-apiserver/app/testing" + "k8s.io/kubernetes/test/integration/framework" + "k8s.io/kubernetes/test/utils" +) + +func TestCertRotation(t *testing.T) { + stopCh := make(chan struct{}) + defer close(stopCh) + + clientSigningKey, err := utils.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + clientSigningCert, err := cert.NewSelfSignedCACert(cert.Config{CommonName: "client-ca"}, clientSigningKey) + if err != nil { + t.Fatal(err) + } + + transport.CertCallbackRefreshDuration = 1 * time.Second + + certDir := os.TempDir() + clientCAFilename := path.Join(certDir, "ca.crt") + + if err := ioutil.WriteFile(clientCAFilename, utils.EncodeCertPEM(clientSigningCert), 0644); err != nil { + t.Fatal(err) + } + + server := apiservertesting.StartTestServerOrDie(t, apiservertesting.NewDefaultTestServerOptions(), []string{ + "--client-ca-file=" + clientCAFilename, + }, framework.SharedEtcd()) + defer server.TearDownFn() + + writeCerts(t, clientSigningCert, clientSigningKey, certDir, 30*time.Second) + + kubeconfig := server.ClientConfig + kubeconfig.CertFile = path.Join(certDir, "client.crt") + kubeconfig.KeyFile = path.Join(certDir, "client.key") + kubeconfig.BearerToken = "" + + client := clientset.NewForConfigOrDie(kubeconfig) + ctx := context.Background() + + w, err := client.CoreV1().ServiceAccounts("default").Watch(ctx, v1.ListOptions{}) + if err != nil { + t.Fatal(err) + } + + select { + case <-w.ResultChan(): + t.Fatal("Watch closed before rotation") + default: + } + + writeCerts(t, clientSigningCert, clientSigningKey, certDir, 5*time.Minute) + + time.Sleep(10 * time.Second) + + // Should have had a rotation; connections will have been closed + select { + case _, ok := <-w.ResultChan(): + assert.Equal(t, false, ok) + default: + t.Fatal("Watch wasn't closed despite rotation") + } + + // Wait for old cert to expire (30s) + time.Sleep(30 * time.Second) + + // Ensure we make requests with the new cert + _, err = client.CoreV1().ServiceAccounts("default").List(ctx, v1.ListOptions{}) + if err != nil { + t.Fatal(err) + } +} + +func TestCertRotationContinuousRequests(t *testing.T) { + stopCh := make(chan struct{}) + defer close(stopCh) + + clientSigningKey, err := utils.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + clientSigningCert, err := cert.NewSelfSignedCACert(cert.Config{CommonName: "client-ca"}, clientSigningKey) + if err != nil { + t.Fatal(err) + } + + transport.CertCallbackRefreshDuration = 1 * time.Second + + certDir := os.TempDir() + clientCAFilename := path.Join(certDir, "ca.crt") + + if err := ioutil.WriteFile(clientCAFilename, utils.EncodeCertPEM(clientSigningCert), 0644); err != nil { + t.Fatal(err) + } + + server := apiservertesting.StartTestServerOrDie(t, apiservertesting.NewDefaultTestServerOptions(), []string{ + "--client-ca-file=" + clientCAFilename, + }, framework.SharedEtcd()) + defer server.TearDownFn() + + writeCerts(t, clientSigningCert, clientSigningKey, certDir, 30*time.Second) + + kubeconfig := server.ClientConfig + kubeconfig.CertFile = path.Join(certDir, "client.crt") + kubeconfig.KeyFile = path.Join(certDir, "client.key") + kubeconfig.BearerToken = "" + + client := clientset.NewForConfigOrDie(kubeconfig) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(10 * time.Second) + + writeCerts(t, clientSigningCert, clientSigningKey, certDir, 5*time.Minute) + + // Wait for old cert to expire (30s) + time.Sleep(30 * time.Second) + cancel() + }() + + for range time.Tick(time.Second) { + _, err := client.CoreV1().ServiceAccounts("default").List(ctx, v1.ListOptions{}) + if err != nil { + if err == ctx.Err() { + return + } + + t.Fatal(err) + } + } +} + +func writeCerts(t *testing.T, clientSigningCert *x509.Certificate, clientSigningKey *rsa.PrivateKey, certDir string, duration time.Duration) { + clientKey, err := utils.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + + privBytes, err := x509.MarshalPKCS8PrivateKey(clientKey) + if err != nil { + t.Fatal(err) + } + + if err := ioutil.WriteFile(path.Join(certDir, "client.key"), pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}), 0666); err != nil { + t.Fatal(err) + } + + serial, err := rand.Int(rand.Reader, new(big.Int).SetInt64(math.MaxInt64)) + if err != nil { + t.Fatal(err) + } + + certTmpl := x509.Certificate{ + Subject: pkix.Name{ + CommonName: "foo", + Organization: []string{"system:masters"}, + }, + SerialNumber: serial, + NotBefore: clientSigningCert.NotBefore, + NotAfter: time.Now().Add(duration).UTC(), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + certDERBytes, err := x509.CreateCertificate(rand.Reader, &certTmpl, clientSigningCert, clientKey.Public(), clientSigningKey) + if err != nil { + t.Fatal(err) + } + + if err := ioutil.WriteFile(path.Join(certDir, "client.crt"), pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDERBytes}), 0666); err != nil { + t.Fatal(err) + } +}