diff --git a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/exec/exec.go b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/exec/exec.go index af21c499537..4957a461a66 100644 --- a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/exec/exec.go +++ b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/exec/exec.go @@ -18,7 +18,6 @@ package exec import ( "bytes" - "context" "crypto/tls" "crypto/x509" "errors" @@ -52,7 +51,6 @@ import ( ) const execInfoEnv = "KUBERNETES_EXEC_INFO" -const onRotateListWarningLength = 1000 const installHintVerboseHelp = ` It looks like you are trying to use a client-go credential plugin that is not installed. @@ -177,6 +175,12 @@ func newAuthenticator(c *cache, config *api.ExecConfig, cluster *clientauthentic return nil, fmt.Errorf("exec plugin: invalid apiVersion %q", config.APIVersion) } + connTracker := connrotation.NewConnectionTracker() + defaultDialer := connrotation.NewDialerWithTracker( + (&net.Dialer{Timeout: 30 * time.Second, KeepAlive: 30 * time.Second}).DialContext, + connTracker, + ) + a := &Authenticator{ cmd: config.Command, args: config.Args, @@ -196,6 +200,9 @@ func newAuthenticator(c *cache, config *api.ExecConfig, cluster *clientauthentic interactive: terminal.IsTerminal(int(os.Stdout.Fd())), now: time.Now, environ: os.Environ, + + defaultDialer: defaultDialer, + connTracker: connTracker, } for _, env := range config.Env { @@ -229,6 +236,11 @@ type Authenticator struct { now func() time.Time environ func() []string + // defaultDialer is used for clients which don't specify a custom dialer + defaultDialer *connrotation.Dialer + // connTracker tracks all connections opened that we need to close when rotating a client certificate + connTracker *connrotation.ConnectionTracker + // Cached results. // // The mutex also guards calling the plugin. Since the plugin could be @@ -236,8 +248,6 @@ type Authenticator struct { mu sync.Mutex cachedCreds *credentials exp time.Time - - onRotateList []func() } type credentials struct { @@ -266,20 +276,12 @@ func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error { } c.TLS.GetCert = a.cert - var dial func(ctx context.Context, network, addr string) (net.Conn, error) + var d *connrotation.Dialer if c.Dial != nil { - dial = c.Dial + // if c has a custom dialer, we have to wrap it + d = connrotation.NewDialerWithTracker(c.Dial, a.connTracker) } else { - dial = (&net.Dialer{Timeout: 30 * time.Second, KeepAlive: 30 * time.Second}).DialContext - } - d := connrotation.NewDialer(dial) - - a.mu.Lock() - defer a.mu.Unlock() - a.onRotateList = append(a.onRotateList, d.CloseAll) - onRotateListLength := len(a.onRotateList) - if onRotateListLength > onRotateListWarningLength { - klog.Warningf("constructing many client instances from the same exec auth config can cause performance problems during cert rotation and can exhaust available network connections; %d clients constructed calling %q", onRotateListLength, a.cmd) + d = a.defaultDialer } c.Dial = d.DialContext @@ -458,9 +460,7 @@ func (a *Authenticator) refreshCredsLocked(r *clientauthentication.Response) err if oldCreds.cert != nil && oldCreds.cert.Leaf != nil { metrics.ClientCertRotationAge.Observe(time.Now().Sub(oldCreds.cert.Leaf.NotBefore)) } - for _, onRotate := range a.onRotateList { - onRotate() - } + a.connTracker.CloseAll() } expiry := time.Time{} diff --git a/staging/src/k8s.io/client-go/util/connrotation/connrotation.go b/staging/src/k8s.io/client-go/util/connrotation/connrotation.go index f98faee47d5..2b9bf72bde6 100644 --- a/staging/src/k8s.io/client-go/util/connrotation/connrotation.go +++ b/staging/src/k8s.io/client-go/util/connrotation/connrotation.go @@ -33,18 +33,40 @@ type DialFunc func(ctx context.Context, network, address string) (net.Conn, erro // Dialer opens connections through Dial and tracks them. type Dialer struct { dial DialFunc + *ConnectionTracker +} +// NewDialer creates a new Dialer instance. +// Equivalent to NewDialerWithTracker(dial, nil). +func NewDialer(dial DialFunc) *Dialer { + return NewDialerWithTracker(dial, nil) +} + +// NewDialerWithTracker creates a new Dialer instance. +// +// If dial is not nil, it will be used to create new underlying connections. +// Otherwise net.DialContext is used. +// If tracker is not nil, it will be used to track new underlying connections. +// Otherwise NewConnectionTracker() is used. +func NewDialerWithTracker(dial DialFunc, tracker *ConnectionTracker) *Dialer { + if tracker == nil { + tracker = NewConnectionTracker() + } + return &Dialer{ + dial: dial, + ConnectionTracker: tracker, + } +} + +// ConnectionTracker keeps track of opened connections +type ConnectionTracker struct { mu sync.Mutex conns map[*closableConn]struct{} } -// NewDialer creates a new Dialer instance. -// -// If dial is not nil, it will be used to create new underlying connections. -// Otherwise net.DialContext is used. -func NewDialer(dial DialFunc) *Dialer { - return &Dialer{ - dial: dial, +// NewConnectionTracker returns a connection tracker for use with NewDialerWithTracker +func NewConnectionTracker() *ConnectionTracker { + return &ConnectionTracker{ conns: make(map[*closableConn]struct{}), } } @@ -52,17 +74,40 @@ func NewDialer(dial DialFunc) *Dialer { // CloseAll forcibly closes all tracked connections. // // Note: new connections may get created before CloseAll returns. -func (d *Dialer) CloseAll() { - d.mu.Lock() - conns := d.conns - d.conns = make(map[*closableConn]struct{}) - d.mu.Unlock() +func (c *ConnectionTracker) CloseAll() { + c.mu.Lock() + conns := c.conns + c.conns = make(map[*closableConn]struct{}) + c.mu.Unlock() for conn := range conns { conn.Close() } } +// Track adds the connection to the list of tracked connections, +// and returns a wrapped copy of the connection that stops tracking the connection +// when it is closed. +func (c *ConnectionTracker) Track(conn net.Conn) net.Conn { + closable := &closableConn{Conn: conn} + + // When the connection is closed, remove it from the map. This will + // be no-op if the connection isn't in the map, e.g. if CloseAll() + // is called. + closable.onClose = func() { + c.mu.Lock() + delete(c.conns, closable) + c.mu.Unlock() + } + + // Start tracking the connection + c.mu.Lock() + c.conns[closable] = struct{}{} + c.mu.Unlock() + + return closable +} + // Dial creates a new tracked connection. func (d *Dialer) Dial(network, address string) (net.Conn, error) { return d.DialContext(context.Background(), network, address) @@ -74,24 +119,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. if err != nil { return nil, err } - - closable := &closableConn{Conn: conn} - - // When the connection is closed, remove it from the map. This will - // be no-op if the connection isn't in the map, e.g. if CloseAll() - // is called. - closable.onClose = func() { - d.mu.Lock() - delete(d.conns, closable) - d.mu.Unlock() - } - - // Start tracking the connection - d.mu.Lock() - d.conns[closable] = struct{}{} - d.mu.Unlock() - - return closable, nil + return d.ConnectionTracker.Track(conn), nil } type closableConn struct {