Set up connection onClose prior to adding to connection map

Kubernetes-commit: aa4113d777dd6c699233e0b6d903e9734e182686
This commit is contained in:
Jordan Liggitt 2020-02-12 11:14:22 -05:00 committed by Kubernetes Publisher
parent e38a845233
commit 03953c1a93
2 changed files with 74 additions and 5 deletions

View File

@ -77,11 +77,6 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
closable := &closableConn{Conn: conn}
// Start tracking the connection
d.mu.Lock()
d.conns[closable] = struct{}{}
d.mu.Unlock()
// When the connection is closed, remove it from the map. This will
// be no-op if the connection isn't in the map, e.g. if CloseAll()
// is called.
@ -91,6 +86,11 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
d.mu.Unlock()
}
// Start tracking the connection
d.mu.Lock()
d.conns[closable] = struct{}{}
d.mu.Unlock()
return closable, nil
}

View File

@ -19,6 +19,8 @@ package connrotation
import (
"context"
"net"
"sync"
"sync/atomic"
"testing"
"time"
)
@ -50,6 +52,73 @@ func TestCloseAll(t *testing.T) {
}
}
// TestCloseAllRace ensures CloseAll works with connections being simultaneously dialed
func TestCloseAllRace(t *testing.T) {
conns := int64(0)
dialer := NewDialer(func(ctx context.Context, network, address string) (net.Conn, error) {
return closeOnlyConn{onClose: func() { atomic.AddInt64(&conns, -1) }}, nil
})
done := make(chan struct{})
wg := &sync.WaitGroup{}
// Close all as fast as we can
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-done:
return
default:
dialer.CloseAll()
}
}
}()
// Dial as fast as we can
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-done:
return
default:
if _, err := dialer.Dial("", ""); err != nil {
t.Error(err)
return
}
atomic.AddInt64(&conns, 1)
}
}
}()
// Soak to ensure no races
time.Sleep(time.Second)
// Signal completion
close(done)
// Wait for goroutines
wg.Wait()
// Ensure CloseAll ran after all dials
dialer.CloseAll()
// Expect all connections to close within 5 seconds
for start := time.Now(); time.Now().Sub(start) < 5*time.Second; time.Sleep(10 * time.Millisecond) {
// Ensure all connections were closed
if c := atomic.LoadInt64(&conns); c == 0 {
break
} else {
t.Logf("got %d open connections, want 0, will retry", c)
}
}
// Ensure all connections were closed
if c := atomic.LoadInt64(&conns); c != 0 {
t.Fatalf("got %d open connections, want 0", c)
}
}
type closeOnlyConn struct {
net.Conn
onClose func()