forked from github/dynamiclistener
Without this change, if a cert is updated (e.g. to add CNs) while the listener is in the middle of Accept()ing a new connection, the connection gets dropped. The server logs then show a message like this: http: TLS handshake error from 127.0.0.1:51232: write tcp 127.0.7.1:8443->127.0.0.1:51232: use of closed network connection, and the client (like a browser) won't necessarily reconnect. This change modifies the GetCertificate routine in the listener's tls.Config to keep track of the state of incoming connections and to close only connections that have completed GetCertificate and are therefore finished with their TLS handshake, so that only old, established connections are closed.
435 lines
8.7 KiB
Go
435 lines
8.7 KiB
Go
package dynamiclistener
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/rancher/dynamiclistener/cert"
|
|
"github.com/rancher/dynamiclistener/factory"
|
|
"github.com/sirupsen/logrus"
|
|
v1 "k8s.io/api/core/v1"
|
|
)
|
|
|
|
type TLSStorage interface {
|
|
Get() (*v1.Secret, error)
|
|
Update(secret *v1.Secret) error
|
|
}
|
|
|
|
type TLSFactory interface {
|
|
Renew(secret *v1.Secret) (*v1.Secret, error)
|
|
AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error)
|
|
Merge(target *v1.Secret, additional *v1.Secret) (*v1.Secret, bool, error)
|
|
Filter(cn ...string) []string
|
|
}
|
|
|
|
type SetFactory interface {
|
|
SetFactory(tls TLSFactory)
|
|
}
|
|
|
|
func NewListener(l net.Listener, storage TLSStorage, caCert *x509.Certificate, caKey crypto.Signer, config Config) (net.Listener, http.Handler, error) {
|
|
if config.CN == "" {
|
|
config.CN = "dynamic"
|
|
}
|
|
if len(config.Organization) == 0 {
|
|
config.Organization = []string{"dynamic"}
|
|
}
|
|
if config.TLSConfig == nil {
|
|
config.TLSConfig = &tls.Config{}
|
|
}
|
|
|
|
dynamicListener := &listener{
|
|
factory: &factory.TLS{
|
|
CACert: caCert,
|
|
CAKey: caKey,
|
|
CN: config.CN,
|
|
Organization: config.Organization,
|
|
FilterCN: allowDefaultSANs(config.SANs, config.FilterCN),
|
|
},
|
|
Listener: l,
|
|
storage: &nonNil{storage: storage},
|
|
sans: config.SANs,
|
|
maxSANs: config.MaxSANs,
|
|
tlsConfig: config.TLSConfig,
|
|
}
|
|
if dynamicListener.tlsConfig == nil {
|
|
dynamicListener.tlsConfig = &tls.Config{}
|
|
}
|
|
dynamicListener.tlsConfig.GetCertificate = dynamicListener.getCertificate
|
|
|
|
if config.CloseConnOnCertChange {
|
|
if len(dynamicListener.tlsConfig.Certificates) == 0 {
|
|
dynamicListener.tlsConfig.NextProtos = []string{"http/1.1"}
|
|
}
|
|
dynamicListener.conns = map[int]*closeWrapper{}
|
|
}
|
|
|
|
if setter, ok := storage.(SetFactory); ok {
|
|
setter.SetFactory(dynamicListener.factory)
|
|
}
|
|
|
|
if config.ExpirationDaysCheck == 0 {
|
|
config.ExpirationDaysCheck = 30
|
|
}
|
|
|
|
tlsListener := tls.NewListener(dynamicListener.WrapExpiration(config.ExpirationDaysCheck), dynamicListener.tlsConfig)
|
|
return tlsListener, dynamicListener.cacheHandler(), nil
|
|
}
|
|
|
|
// allowDefaultSANs wraps the CN filter next so that every name listed in sans
// is accepted without consulting next.
//
// It returns nil when next is nil and next itself when sans is empty.
// Otherwise the returned filter splits its arguments into names found in sans
// and all others, passes only the others to next, and returns the sans
// matches followed by whatever next returned.
func allowDefaultSANs(sans []string, next func(...string) []string) func(...string) []string {
	if next == nil {
		return nil
	}
	if len(sans) == 0 {
		return next
	}

	allowed := make(map[string]struct{}, len(sans))
	for _, san := range sans {
		allowed[san] = struct{}{}
	}

	return func(names ...string) []string {
		var good, unknown []string
		for _, name := range names {
			if _, ok := allowed[name]; ok {
				good = append(good, name)
				continue
			}
			unknown = append(unknown, name)
		}

		return append(good, next(unknown...)...)
	}
}
|
|
|
|
// cancelClose is a net.Listener that carries a cancel function to be run
// when the listener is closed.
type cancelClose struct {
	// Listener is the wrapped listener that all calls are delegated to.
	net.Listener
	// cancel is invoked by Close before the wrapped listener is closed.
	cancel func()
}
|
|
|
|
func (c *cancelClose) Close() error {
|
|
c.cancel()
|
|
return c.Listener.Close()
|
|
}
|
|
|
|
// Config holds the optional settings accepted by NewListener.
type Config struct {
	// CN is passed to the certificate factory; NewListener uses "dynamic"
	// when it is empty.
	CN string
	// Organization is passed to the certificate factory; NewListener uses
	// []string{"dynamic"} when it is empty.
	Organization []string
	// TLSConfig is the TLS configuration the listener serves with. A zero
	// tls.Config is used when nil. NewListener sets GetCertificate on it.
	TLSConfig *tls.Config
	// SANs are names added to the certificate on the first Accept call and
	// on every later certificate update. When FilterCN is set they also
	// bypass that filter.
	SANs []string
	// MaxSANs is forwarded to the factory's NeedsUpdate check.
	MaxSANs int
	// ExpirationDaysCheck is the day threshold handed to the periodic
	// expiration check; NewListener uses 30 when it is zero.
	ExpirationDaysCheck int
	// CloseConnOnCertChange enables tracking of accepted connections so
	// they can be closed when a changed certificate is loaded.
	CloseConnOnCertChange bool
	// FilterCN, when non-nil, decides which requested names other than
	// SANs may be added to the certificate.
	FilterCN func(...string) []string
}
|
|
|
|
type listener struct {
|
|
sync.RWMutex
|
|
net.Listener
|
|
|
|
conns map[int]*closeWrapper
|
|
connID int
|
|
connLock sync.Mutex
|
|
|
|
factory TLSFactory
|
|
storage TLSStorage
|
|
version string
|
|
tlsConfig *tls.Config
|
|
cert *tls.Certificate
|
|
sans []string
|
|
maxSANs int
|
|
init sync.Once
|
|
}
|
|
|
|
func (l *listener) WrapExpiration(days int) net.Listener {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() {
|
|
time.Sleep(30 * time.Second)
|
|
|
|
for {
|
|
wait := 6 * time.Hour
|
|
if err := l.checkExpiration(days); err != nil {
|
|
logrus.Errorf("failed to check and renew dynamic cert: %v", err)
|
|
// Don't go into short retry loop if we're using a static (user-provided) cert.
|
|
// We will still check and print an error every six hours until the user updates the secret with
|
|
// a cert that is not about to expire. Hopefully this will prompt them to take action.
|
|
if err != cert.ErrStaticCert {
|
|
wait = 5 * time.Minute
|
|
}
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(wait):
|
|
}
|
|
}
|
|
}()
|
|
|
|
return &cancelClose{
|
|
cancel: cancel,
|
|
Listener: l,
|
|
}
|
|
}
|
|
|
|
func (l *listener) checkExpiration(days int) error {
|
|
l.Lock()
|
|
defer l.Unlock()
|
|
|
|
if days == 0 {
|
|
return nil
|
|
}
|
|
|
|
if l.cert == nil {
|
|
return nil
|
|
}
|
|
|
|
secret, err := l.storage.Get()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
keyPair, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
certParsed, err := x509.ParseCertificate(keyPair.Certificate[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if cert.IsCertExpired(certParsed, days) {
|
|
secret, err := l.factory.Renew(secret)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := l.storage.Update(secret); err != nil {
|
|
return err
|
|
}
|
|
// clear version to force cert reload
|
|
l.version = ""
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (l *listener) Accept() (net.Conn, error) {
|
|
l.init.Do(func() {
|
|
if len(l.sans) > 0 {
|
|
l.updateCert(l.sans...)
|
|
}
|
|
})
|
|
|
|
conn, err := l.Listener.Accept()
|
|
if err != nil {
|
|
return conn, err
|
|
}
|
|
|
|
addr := conn.LocalAddr()
|
|
if addr == nil {
|
|
return conn, nil
|
|
}
|
|
|
|
host, _, err := net.SplitHostPort(addr.String())
|
|
if err != nil {
|
|
logrus.Errorf("failed to parse network %s: %v", addr.Network(), err)
|
|
return conn, nil
|
|
}
|
|
|
|
if !strings.Contains(host, ":") {
|
|
if err := l.updateCert(host); err != nil {
|
|
logrus.Infof("failed to create TLS cert for: %s, %v", host, err)
|
|
}
|
|
}
|
|
|
|
if l.conns != nil {
|
|
conn = l.wrap(conn)
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
func (l *listener) wrap(conn net.Conn) net.Conn {
|
|
l.connLock.Lock()
|
|
defer l.connLock.Unlock()
|
|
l.connID++
|
|
|
|
wrapper := &closeWrapper{
|
|
Conn: conn,
|
|
id: l.connID,
|
|
l: l,
|
|
}
|
|
l.conns[l.connID] = wrapper
|
|
|
|
return wrapper
|
|
}
|
|
|
|
type closeWrapper struct {
|
|
net.Conn
|
|
id int
|
|
l *listener
|
|
ready bool
|
|
}
|
|
|
|
func (c *closeWrapper) close() error {
|
|
delete(c.l.conns, c.id)
|
|
return c.Conn.Close()
|
|
}
|
|
|
|
func (c *closeWrapper) Close() error {
|
|
c.l.connLock.Lock()
|
|
defer c.l.connLock.Unlock()
|
|
return c.close()
|
|
}
|
|
|
|
func (l *listener) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
newConn := hello.Conn
|
|
if hello.ServerName != "" {
|
|
if err := l.updateCert(hello.ServerName); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return l.loadCert(newConn.(*closeWrapper))
|
|
}
|
|
|
|
func (l *listener) updateCert(cn ...string) error {
|
|
cn = l.factory.Filter(cn...)
|
|
if len(cn) == 0 {
|
|
return nil
|
|
}
|
|
|
|
l.RLock()
|
|
defer l.RUnlock()
|
|
|
|
secret, err := l.storage.Get()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !factory.IsStatic(secret) && !factory.NeedsUpdate(l.maxSANs, secret, cn...) {
|
|
return nil
|
|
}
|
|
|
|
l.RUnlock()
|
|
l.Lock()
|
|
defer l.RLock()
|
|
defer l.Unlock()
|
|
|
|
secret, updated, err := l.factory.AddCN(secret, append(l.sans, cn...)...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if updated {
|
|
if err := l.storage.Update(secret); err != nil {
|
|
return err
|
|
}
|
|
// clear version to force cert reload
|
|
l.version = ""
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (l *listener) loadCert(currentConn *closeWrapper) (*tls.Certificate, error) {
|
|
l.RLock()
|
|
defer l.RUnlock()
|
|
|
|
secret, err := l.storage.Get()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if l.cert != nil && l.version == secret.ResourceVersion && secret.ResourceVersion != "" {
|
|
return l.cert, nil
|
|
}
|
|
|
|
defer l.RLock()
|
|
l.RUnlock()
|
|
l.Lock()
|
|
defer l.Unlock()
|
|
|
|
secret, err = l.storage.Get()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if l.cert != nil && l.version == secret.ResourceVersion && secret.ResourceVersion != "" {
|
|
return l.cert, nil
|
|
}
|
|
|
|
cert, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// cert has changed, close closeWrapper wrapped connections if this isn't the first load
|
|
if l.conns != nil && l.cert != nil {
|
|
l.connLock.Lock()
|
|
for _, conn := range l.conns {
|
|
// Don't close a connection that's in the middle of completing a TLS handshake
|
|
if !conn.ready {
|
|
continue
|
|
}
|
|
_ = conn.close()
|
|
}
|
|
l.conns[currentConn.id].ready = true
|
|
l.connLock.Unlock()
|
|
}
|
|
|
|
l.cert = &cert
|
|
l.version = secret.ResourceVersion
|
|
return l.cert, nil
|
|
}
|
|
|
|
func (l *listener) cacheHandler() http.Handler {
|
|
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
|
h, _, err := net.SplitHostPort(req.Host)
|
|
if err != nil {
|
|
h = req.Host
|
|
}
|
|
|
|
ip := net.ParseIP(h)
|
|
if len(ip) > 0 {
|
|
for _, v := range req.Header["User-Agent"] {
|
|
if strings.Contains(strings.ToLower(v), "mozilla") {
|
|
return
|
|
}
|
|
}
|
|
|
|
l.updateCert(h)
|
|
}
|
|
})
|
|
}
|
|
|
|
type nonNil struct {
|
|
sync.Mutex
|
|
storage TLSStorage
|
|
}
|
|
|
|
func (n *nonNil) Get() (*v1.Secret, error) {
|
|
n.Lock()
|
|
defer n.Unlock()
|
|
|
|
s, err := n.storage.Get()
|
|
if err != nil || s == nil {
|
|
return &v1.Secret{}, err
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
func (n *nonNil) Update(secret *v1.Secret) error {
|
|
n.Lock()
|
|
defer n.Unlock()
|
|
|
|
return n.storage.Update(secret)
|
|
}
|