diff --git a/listener.go b/listener.go index 1914514..28db3ff 100644 --- a/listener.go +++ b/listener.go @@ -59,6 +59,10 @@ func NewListener(l net.Listener, storage TLSStorage, caCert *x509.Certificate, c } dynamicListener.tlsConfig.GetCertificate = dynamicListener.getCertificate + if config.CloseConnOnCertChange { + dynamicListener.conns = map[int]*closeWrapper{} + } + if setter, ok := storage.(SetFactory); ok { setter.SetFactory(dynamicListener.factory) } @@ -82,17 +86,22 @@ func (c *cancelClose) Close() error { } type Config struct { - CN string - Organization []string - TLSConfig *tls.Config - SANs []string - ExpirationDaysCheck int + CN string + Organization []string + TLSConfig *tls.Config + SANs []string + ExpirationDaysCheck int + CloseConnOnCertChange bool } type listener struct { sync.RWMutex net.Listener + conns map[int]*closeWrapper + connID int + connLock sync.Mutex + factory TLSFactory storage TLSStorage version string @@ -194,9 +203,45 @@ func (l *listener) Accept() (net.Conn, error) { } } + if l.conns != nil { + conn = l.wrap(conn) + } + return conn, nil } +func (l *listener) wrap(conn net.Conn) net.Conn { + l.connLock.Lock() + defer l.connLock.Unlock() + l.connID++ + + wrapper := &closeWrapper{ + Conn: conn, + id: l.connID, + l: l, + } + l.conns[l.connID] = wrapper + + return wrapper +} + +type closeWrapper struct { + net.Conn + id int + l *listener +} + +func (c *closeWrapper) close() error { + delete(c.l.conns, c.id) + return c.Conn.Close() +} + +func (c *closeWrapper) Close() error { + c.l.Lock() + defer c.l.Unlock() + return c.close() +} + func (l *listener) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { if hello.ServerName != "" { if err := l.updateCert(hello.ServerName); err != nil { @@ -238,6 +283,14 @@ func (l *listener) updateCert(cn ...string) error { l.version = "" } + if l.conns != nil { + l.connLock.Lock() + for _, conn := range l.conns { + _ = conn.close() + } + l.connLock.Unlock() + } + return nil }