forked from github/dynamiclistener
Add option to close connections on cert change
This commit is contained in:
parent
3f92468568
commit
8545ce98db
63
listener.go
63
listener.go
@ -59,6 +59,10 @@ func NewListener(l net.Listener, storage TLSStorage, caCert *x509.Certificate, c
|
||||
}
|
||||
dynamicListener.tlsConfig.GetCertificate = dynamicListener.getCertificate
|
||||
|
||||
if config.CloseConnOnCertChange {
|
||||
dynamicListener.conns = map[int]*closeWrapper{}
|
||||
}
|
||||
|
||||
if setter, ok := storage.(SetFactory); ok {
|
||||
setter.SetFactory(dynamicListener.factory)
|
||||
}
|
||||
@ -82,17 +86,22 @@ func (c *cancelClose) Close() error {
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
CN string
|
||||
Organization []string
|
||||
TLSConfig *tls.Config
|
||||
SANs []string
|
||||
ExpirationDaysCheck int
|
||||
CN string
|
||||
Organization []string
|
||||
TLSConfig *tls.Config
|
||||
SANs []string
|
||||
ExpirationDaysCheck int
|
||||
CloseConnOnCertChange bool
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
sync.RWMutex
|
||||
net.Listener
|
||||
|
||||
conns map[int]*closeWrapper
|
||||
connID int
|
||||
connLock sync.Mutex
|
||||
|
||||
factory TLSFactory
|
||||
storage TLSStorage
|
||||
version string
|
||||
@ -194,9 +203,45 @@ func (l *listener) Accept() (net.Conn, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if l.conns != nil {
|
||||
conn = l.wrap(conn)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (l *listener) wrap(conn net.Conn) net.Conn {
|
||||
l.connLock.Lock()
|
||||
defer l.connLock.Unlock()
|
||||
l.connID++
|
||||
|
||||
wrapper := &closeWrapper{
|
||||
Conn: conn,
|
||||
id: l.connID,
|
||||
l: l,
|
||||
}
|
||||
l.conns[l.connID] = wrapper
|
||||
|
||||
return wrapper
|
||||
}
|
||||
|
||||
type closeWrapper struct {
|
||||
net.Conn
|
||||
id int
|
||||
l *listener
|
||||
}
|
||||
|
||||
func (c *closeWrapper) close() error {
|
||||
delete(c.l.conns, c.id)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *closeWrapper) Close() error {
|
||||
c.l.Lock()
|
||||
defer c.l.Unlock()
|
||||
return c.close()
|
||||
}
|
||||
|
||||
func (l *listener) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if hello.ServerName != "" {
|
||||
if err := l.updateCert(hello.ServerName); err != nil {
|
||||
@ -238,6 +283,14 @@ func (l *listener) updateCert(cn ...string) error {
|
||||
l.version = ""
|
||||
}
|
||||
|
||||
if l.conns != nil {
|
||||
l.connLock.Lock()
|
||||
for _, conn := range l.conns {
|
||||
_ = conn.close()
|
||||
}
|
||||
l.connLock.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user