diff --git a/util/connrotation/connrotation.go b/util/connrotation/connrotation.go new file mode 100644 index 00000000..235a9e01 --- /dev/null +++ b/util/connrotation/connrotation.go @@ -0,0 +1,105 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package connrotation implements a connection dialer that tracks and can close +// all created connections. +// +// This is used for credential rotation of long-lived connections, when there's +// no way to re-authenticate on a live connection. +package connrotation + +import ( + "context" + "net" + "sync" +) + +// DialFunc is a shorthand for signature of net.DialContext. +type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// Dialer opens connections through Dial and tracks them. +type Dialer struct { + dial DialFunc + + mu sync.Mutex + conns map[*closableConn]struct{} +} + +// NewDialer creates a new Dialer instance. +// +// If dial is not nil, it will be used to create new underlying connections. +// Otherwise net.DialContext is used. +func NewDialer(dial DialFunc) *Dialer { + return &Dialer{ + dial: dial, + conns: make(map[*closableConn]struct{}), + } +} + +// CloseAll forcibly closes all tracked connections. +// +// Note: new connections may get created before CloseAll returns. +func (d *Dialer) CloseAll() { + d.mu.Lock() + conns := d.conns + d.conns = make(map[*closableConn]struct{}) + d.mu.Unlock() + + for conn := range conns { + conn.Close() + } +} + +// Dial creates a new tracked connection. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// DialContext creates a new tracked connection. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.dial(ctx, network, address) + if err != nil { + return nil, err + } + + closable := &closableConn{Conn: conn} + + // Start tracking the connection + d.mu.Lock() + d.conns[closable] = struct{}{} + d.mu.Unlock() + + // When the connection is closed, remove it from the map. This will + // be no-op if the connection isn't in the map, e.g. if CloseAll() + // is called. + closable.onClose = func() { + d.mu.Lock() + delete(d.conns, closable) + d.mu.Unlock() + } + + return closable, nil +} + +type closableConn struct { + onClose func() + net.Conn +} + +func (c *closableConn) Close() error { + go c.onClose() + return c.Conn.Close() +} diff --git a/util/connrotation/connrotation_test.go b/util/connrotation/connrotation_test.go new file mode 100644 index 00000000..a618f296 --- /dev/null +++ b/util/connrotation/connrotation_test.go @@ -0,0 +1,61 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package connrotation + +import ( + "context" + "net" + "testing" + "time" +) + +func TestCloseAll(t *testing.T) { + closed := make(chan struct{}) + dialFn := func(ctx context.Context, network, address string) (net.Conn, error) { + return closeOnlyConn{onClose: func() { closed <- struct{}{} }}, nil + } + dialer := NewDialer(dialFn) + + const numConns = 10 + + // Outer loop to ensure Dialer is re-usable after CloseAll. + for i := 0; i < 5; i++ { + for j := 0; j < numConns; j++ { + if _, err := dialer.Dial("", ""); err != nil { + t.Fatal(err) + } + } + dialer.CloseAll() + for j := 0; j < numConns; j++ { + select { + case <-closed: + case <-time.After(time.Second): + t.Fatalf("iteration %d: 1s after CloseAll only %d/%d connections closed", i, j, numConns) + } + } + } +} + +type closeOnlyConn struct { + net.Conn + onClose func() +} + +func (c closeOnlyConn) Close() error { + go c.onClose() + return nil +}