diff --git a/pkg/kubelet/certificate/BUILD b/pkg/kubelet/certificate/BUILD index 23ec4341c53..6910df4d938 100644 --- a/pkg/kubelet/certificate/BUILD +++ b/pkg/kubelet/certificate/BUILD @@ -26,6 +26,7 @@ go_library( "//vendor/k8s.io/client-go/kubernetes/typed/certificates/v1beta1:go_default_library", "//vendor/k8s.io/client-go/rest:go_default_library", "//vendor/k8s.io/client-go/util/certificate:go_default_library", + "//vendor/k8s.io/client-go/util/connrotation:go_default_library", ], ) diff --git a/pkg/kubelet/certificate/transport.go b/pkg/kubelet/certificate/transport.go index f980cb6eb14..76caa20a753 100644 --- a/pkg/kubelet/certificate/transport.go +++ b/pkg/kubelet/certificate/transport.go @@ -17,12 +17,10 @@ limitations under the License. package certificate import ( - "context" "crypto/tls" "fmt" "net" "net/http" - "sync" "time" "github.com/golang/glog" @@ -31,6 +29,7 @@ import ( "k8s.io/apimachinery/pkg/util/wait" restclient "k8s.io/client-go/rest" "k8s.io/client-go/util/certificate" + "k8s.io/client-go/util/connrotation" ) // UpdateTransport instruments a restconfig with a transport that dynamically uses @@ -64,11 +63,7 @@ func updateTransport(stopCh <-chan struct{}, period time.Duration, clientConfig return nil, fmt.Errorf("there is already a transport or dialer configured") } - // Custom dialer that will track all connections it creates. - t := &connTracker{ - dialer: &net.Dialer{Timeout: 30 * time.Second, KeepAlive: 30 * time.Second}, - conns: make(map[*closableConn]struct{}), - } + d := connrotation.NewDialer((&net.Dialer{Timeout: 30 * time.Second, KeepAlive: 30 * time.Second}).DialContext) tlsConfig, err := restclient.TLSConfigFor(clientConfig) if err != nil { @@ -128,7 +123,7 @@ func updateTransport(stopCh <-chan struct{}, period time.Duration, clientConfig // to reperform its TLS handshake with new cert. // // See: https://github.com/kubernetes-incubator/bootkube/pull/663#issuecomment-318506493 - t.closeAllConns() + d.CloseAll() }, period, stopCh) } @@ -137,7 +132,7 @@ func updateTransport(stopCh <-chan struct{}, period time.Duration, clientConfig TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: tlsConfig, MaxIdleConnsPerHost: 25, - DialContext: t.DialContext, // Use custom dialer. + DialContext: d.DialContext, // Use custom dialer. }) // Zero out all existing TLS options since our new transport enforces them. @@ -149,60 +144,5 @@ func updateTransport(stopCh <-chan struct{}, period time.Duration, clientConfig clientConfig.CAFile = "" clientConfig.Insecure = false - return t.closeAllConns, nil -} - -// connTracker is a dialer that tracks all open connections it creates. -type connTracker struct { - dialer *net.Dialer - - mu sync.Mutex - conns map[*closableConn]struct{} -} - -// closeAllConns forcibly closes all tracked connections. -func (c *connTracker) closeAllConns() { - c.mu.Lock() - conns := c.conns - c.conns = make(map[*closableConn]struct{}) - c.mu.Unlock() - - for conn := range conns { - conn.Close() - } -} - -func (c *connTracker) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - conn, err := c.dialer.DialContext(ctx, network, address) - if err != nil { - return nil, err - } - - closable := &closableConn{Conn: conn} - - // Start tracking the connection - c.mu.Lock() - c.conns[closable] = struct{}{} - c.mu.Unlock() - - // When the connection is closed, remove it from the map. This will - // be no-op if the connection isn't in the map, e.g. if closeAllConns() - // is called. - closable.onClose = func() { - c.mu.Lock() - delete(c.conns, closable) - c.mu.Unlock() - } - - return closable, nil -} - -type closableConn struct { - onClose func() - net.Conn -} - -func (c *closableConn) Close() error { - go c.onClose() - return c.Conn.Close() + return d.CloseAll, nil } diff --git a/staging/BUILD b/staging/BUILD index c1fae9e041a..1f6015e3520 100644 --- a/staging/BUILD +++ b/staging/BUILD @@ -162,6 +162,7 @@ filegroup( "//staging/src/k8s.io/client-go/util/buffer:all-srcs", "//staging/src/k8s.io/client-go/util/cert:all-srcs", "//staging/src/k8s.io/client-go/util/certificate:all-srcs", + "//staging/src/k8s.io/client-go/util/connrotation:all-srcs", "//staging/src/k8s.io/client-go/util/exec:all-srcs", "//staging/src/k8s.io/client-go/util/flowcontrol:all-srcs", "//staging/src/k8s.io/client-go/util/homedir:all-srcs", diff --git a/staging/src/k8s.io/client-go/util/connrotation/BUILD b/staging/src/k8s.io/client-go/util/connrotation/BUILD new file mode 100644 index 00000000000..5744cfd1e51 --- /dev/null +++ b/staging/src/k8s.io/client-go/util/connrotation/BUILD @@ -0,0 +1,28 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "go_default_library", + srcs = ["connrotation.go"], + importpath = "k8s.io/client-go/util/connrotation", + visibility = ["//visibility:public"], +) + +go_test( + name = "go_default_test", + srcs = ["connrotation_test.go"], + embed = [":go_default_library"], +) + +filegroup( + name = "package-srcs", + srcs = glob(["**"]), + tags = ["automanaged"], + visibility = ["//visibility:private"], +) + +filegroup( + name = "all-srcs", + srcs = [":package-srcs"], + tags = ["automanaged"], + visibility = ["//visibility:public"], +) diff --git a/staging/src/k8s.io/client-go/util/connrotation/connrotation.go b/staging/src/k8s.io/client-go/util/connrotation/connrotation.go new file mode 100644 index 00000000000..235a9e01987 --- /dev/null +++ b/staging/src/k8s.io/client-go/util/connrotation/connrotation.go @@ -0,0 +1,105 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package connrotation implements a connection dialer that tracks and can close +// all created connections. +// +// This is used for credential rotation of long-lived connections, when there's +// no way to re-authenticate on a live connection. +package connrotation + +import ( + "context" + "net" + "sync" +) + +// DialFunc is a shorthand for signature of net.DialContext. +type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// Dialer opens connections through Dial and tracks them. +type Dialer struct { + dial DialFunc + + mu sync.Mutex + conns map[*closableConn]struct{} +} + +// NewDialer creates a new Dialer instance. +// +// If dial is not nil, it will be used to create new underlying connections. +// Otherwise net.DialContext is used. +func NewDialer(dial DialFunc) *Dialer { + return &Dialer{ + dial: dial, + conns: make(map[*closableConn]struct{}), + } +} + +// CloseAll forcibly closes all tracked connections. +// +// Note: new connections may get created before CloseAll returns. +func (d *Dialer) CloseAll() { + d.mu.Lock() + conns := d.conns + d.conns = make(map[*closableConn]struct{}) + d.mu.Unlock() + + for conn := range conns { + conn.Close() + } +} + +// Dial creates a new tracked connection. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// DialContext creates a new tracked connection. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.dial(ctx, network, address) + if err != nil { + return nil, err + } + + closable := &closableConn{Conn: conn} + + // Start tracking the connection + d.mu.Lock() + d.conns[closable] = struct{}{} + d.mu.Unlock() + + // When the connection is closed, remove it from the map. This will + // be no-op if the connection isn't in the map, e.g. if CloseAll() + // is called. + closable.onClose = func() { + d.mu.Lock() + delete(d.conns, closable) + d.mu.Unlock() + } + + return closable, nil +} + +type closableConn struct { + onClose func() + net.Conn +} + +func (c *closableConn) Close() error { + go c.onClose() + return c.Conn.Close() +} diff --git a/staging/src/k8s.io/client-go/util/connrotation/connrotation_test.go b/staging/src/k8s.io/client-go/util/connrotation/connrotation_test.go new file mode 100644 index 00000000000..a618f2961ba --- /dev/null +++ b/staging/src/k8s.io/client-go/util/connrotation/connrotation_test.go @@ -0,0 +1,61 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package connrotation + +import ( + "context" + "net" + "testing" + "time" +) + +func TestCloseAll(t *testing.T) { + closed := make(chan struct{}) + dialFn := func(ctx context.Context, network, address string) (net.Conn, error) { + return closeOnlyConn{onClose: func() { closed <- struct{}{} }}, nil + } + dialer := NewDialer(dialFn) + + const numConns = 10 + + // Outer loop to ensure Dialer is re-usable after CloseAll. + for i := 0; i < 5; i++ { + for j := 0; j < numConns; j++ { + if _, err := dialer.Dial("", ""); err != nil { + t.Fatal(err) + } + } + dialer.CloseAll() + for j := 0; j < numConns; j++ { + select { + case <-closed: + case <-time.After(time.Second): + t.Fatalf("iteration %d: 1s after CloseAll only %d/%d connections closed", i, j, numConns) + } + } + } +} + +type closeOnlyConn struct { + net.Conn + onClose func() +} + +func (c closeOnlyConn) Close() error { + go c.onClose() + return nil +}