diff --git a/listener.go b/listener.go index bd310ee..7ba32fb 100644 --- a/listener.go +++ b/listener.go @@ -275,8 +275,9 @@ func (l *listener) wrap(conn net.Conn) net.Conn { type closeWrapper struct { net.Conn - id int - l *listener + id int + l *listener + ready bool } func (c *closeWrapper) close() error { @@ -291,13 +292,14 @@ func (c *closeWrapper) Close() error { } func (l *listener) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + newConn := hello.Conn if hello.ServerName != "" { if err := l.updateCert(hello.ServerName); err != nil { return nil, err } } - return l.loadCert() + return l.loadCert(newConn.(*closeWrapper)) } func (l *listener) updateCert(cn ...string) error { @@ -339,7 +341,7 @@ func (l *listener) updateCert(cn ...string) error { return nil } -func (l *listener) loadCert() (*tls.Certificate, error) { +func (l *listener) loadCert(currentConn *closeWrapper) (*tls.Certificate, error) { l.RLock() defer l.RUnlock() @@ -373,8 +375,13 @@ func (l *listener) loadCert() (*tls.Certificate, error) { if l.conns != nil && l.cert != nil { l.connLock.Lock() for _, conn := range l.conns { + // Don't close a connection that's in the middle of completing a TLS handshake + if !conn.ready { + continue + } _ = conn.close() } + l.conns[currentConn.id].ready = true l.connLock.Unlock() }