mirror of
https://github.com/rancher/dynamiclistener.git
synced 2025-09-12 13:03:50 +00:00
Refactor to not include a server by default
This commit is contained in:
198
listener.go
Normal file
198
listener.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package dynamiclistener
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rancher/dynamiclistener/factory"
|
||||
"github.com/sirupsen/logrus"
|
||||
v1 "k8s.io/api/core/v1"
|
||||
)
|
||||
|
||||
type TLSStorage interface {
|
||||
Get() (*v1.Secret, error)
|
||||
Update(secret *v1.Secret) error
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
CN string
|
||||
Organization []string
|
||||
TLSConfig tls.Config
|
||||
SANs []string
|
||||
}
|
||||
|
||||
func NewListener(l net.Listener, storage TLSStorage, caCert *x509.Certificate, caKey crypto.Signer, config Config) (net.Listener, http.Handler, error) {
|
||||
if config.CN == "" {
|
||||
config.CN = "dynamic"
|
||||
}
|
||||
if len(config.Organization) == 0 {
|
||||
config.Organization = []string{"dynamic"}
|
||||
}
|
||||
|
||||
dynamicListener := &listener{
|
||||
factory: &factory.TLS{
|
||||
CACert: caCert,
|
||||
CAKey: caKey,
|
||||
CN: config.CN,
|
||||
Organization: config.Organization,
|
||||
},
|
||||
Listener: l,
|
||||
storage: &nonNil{storage: storage},
|
||||
sans: config.SANs,
|
||||
tlsConfig: config.TLSConfig,
|
||||
}
|
||||
dynamicListener.tlsConfig.GetCertificate = dynamicListener.getCertificate
|
||||
|
||||
return tls.NewListener(dynamicListener, &dynamicListener.tlsConfig), dynamicListener.cacheHandler(), nil
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
sync.RWMutex
|
||||
net.Listener
|
||||
|
||||
factory *factory.TLS
|
||||
storage TLSStorage
|
||||
version string
|
||||
tlsConfig tls.Config
|
||||
cert *tls.Certificate
|
||||
sans []string
|
||||
}
|
||||
|
||||
func (l *listener) Accept() (net.Conn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return conn, err
|
||||
}
|
||||
|
||||
addr := conn.RemoteAddr()
|
||||
if addr == nil {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to parse network %s: %v", addr.Network(), err)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
return conn, l.updateCert(host)
|
||||
}
|
||||
|
||||
func (l *listener) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if hello.ServerName != "" {
|
||||
if err := l.updateCert(hello.ServerName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return l.loadCert()
|
||||
}
|
||||
|
||||
func (l *listener) updateCert(cn string) error {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
|
||||
secret, err := l.storage.Get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !factory.NeedsUpdate(secret, append(l.sans, cn)...) {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.RUnlock()
|
||||
l.Lock()
|
||||
defer l.RLock()
|
||||
defer l.Unlock()
|
||||
|
||||
secret, updated, err := l.factory.AddCN(secret, append(l.sans, cn)...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := l.storage.Update(secret); err != nil {
|
||||
return err
|
||||
}
|
||||
// clear version to force cert reload
|
||||
l.version = ""
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *listener) loadCert() (*tls.Certificate, error) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
|
||||
secret, err := l.storage.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if l.cert != nil && l.version == secret.ResourceVersion {
|
||||
return l.cert, nil
|
||||
}
|
||||
|
||||
defer l.RLock()
|
||||
l.RUnlock()
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
secret, err = l.storage.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if l.cert != nil && l.version == secret.ResourceVersion {
|
||||
return l.cert, nil
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.cert = &cert
|
||||
return l.cert, nil
|
||||
}
|
||||
|
||||
func (l *listener) cacheHandler() http.Handler {
|
||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
||||
h, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
h = req.Host
|
||||
}
|
||||
|
||||
ip := net.ParseIP(h)
|
||||
if len(ip) > 0 {
|
||||
l.updateCert(h)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type nonNil struct {
|
||||
sync.Mutex
|
||||
storage TLSStorage
|
||||
}
|
||||
|
||||
func (n *nonNil) Get() (*v1.Secret, error) {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
s, err := n.storage.Get()
|
||||
if err != nil || s == nil {
|
||||
return &v1.Secret{}, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (n *nonNil) Update(secret *v1.Secret) error {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
return n.storage.Update(secret)
|
||||
}
|
Reference in New Issue
Block a user