diff --git a/factory/gen.go b/factory/gen.go index 2341748..b54e9a0 100644 --- a/factory/gen.go +++ b/factory/gen.go @@ -29,16 +29,20 @@ type TLS struct { Organization []string } -func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string, err error) { - var ( - cns []string - digest = sha256.New() - ) +func cns(secret *v1.Secret) (cns []string) { for k, v := range secret.Annotations { if strings.HasPrefix(k, cnPrefix) { cns = append(cns, v) } } + return +} + +func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string, err error) { + var ( + cns = cns(secret) + digest = sha256.New() + ) sort.Strings(cns) @@ -56,6 +60,10 @@ func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string, return } +func (t *TLS) Merge(secret, other *v1.Secret) (*v1.Secret, bool, error) { + return t.AddCN(secret, cns(other)...) +} + func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) { var ( err error diff --git a/listener.go b/listener.go index eaec9b5..68cb8e9 100644 --- a/listener.go +++ b/listener.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "net" "net/http" + "strings" "sync" "github.com/rancher/dynamiclistener/factory" @@ -18,6 +19,10 @@ type TLSStorage interface { Update(secret *v1.Secret) error } +type SetFactory interface { + SetFactory(tls *factory.TLS) +} + type Config struct { CN string Organization []string @@ -47,6 +52,10 @@ func NewListener(l net.Listener, storage TLSStorage, caCert *x509.Certificate, c } dynamicListener.tlsConfig.GetCertificate = dynamicListener.getCertificate + if setter, ok := storage.(SetFactory); ok { + setter.SetFactory(dynamicListener.factory) + } + return tls.NewListener(dynamicListener, &dynamicListener.tlsConfig), dynamicListener.cacheHandler(), nil } @@ -60,9 +69,16 @@ type listener struct { tlsConfig tls.Config cert *tls.Certificate sans []string + init sync.Once } func (l *listener) Accept() (net.Conn, error) { + l.init.Do(func() { + if len(l.sans) > 0 { + l.updateCert(l.sans...) + } + }) + conn, err := l.Listener.Accept() if err != nil { return conn, err @@ -96,7 +112,7 @@ func (l *listener) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, return l.loadCert() } -func (l *listener) updateCert(cn string) error { +func (l *listener) updateCert(cn ...string) error { l.RLock() defer l.RUnlock() @@ -105,7 +121,7 @@ func (l *listener) updateCert(cn string) error { return err } - if !factory.NeedsUpdate(secret, append(l.sans, cn)...) { + if !factory.NeedsUpdate(secret, cn...) { return nil } @@ -114,7 +130,7 @@ func (l *listener) updateCert(cn string) error { defer l.RLock() defer l.Unlock() - secret, updated, err := l.factory.AddCN(secret, append(l.sans, cn)...) + secret, updated, err := l.factory.AddCN(secret, append(l.sans, cn...)...) if err != nil { return err } diff --git a/storage/kubernetes/controller.go b/storage/kubernetes/controller.go index d6c163c..f903662 100644 --- a/storage/kubernetes/controller.go +++ b/storage/kubernetes/controller.go @@ -6,6 +6,7 @@ import ( "time" "github.com/rancher/dynamiclistener" + "github.com/rancher/dynamiclistener/factory" "github.com/rancher/wrangler-api/pkg/generated/controllers/core" v1controller "github.com/rancher/wrangler-api/pkg/generated/controllers/core/v1" "github.com/rancher/wrangler/pkg/start" @@ -54,6 +55,11 @@ type storage struct { storage dynamiclistener.TLSStorage secrets v1controller.SecretClient ctx context.Context + tls *factory.TLS +} + +func (s *storage) SetFactory(tls *factory.TLS) { + s.tls = tls } func (s *storage) init(secrets v1controller.SecretController) { @@ -105,6 +111,12 @@ func (s *storage) saveInK8s(secret *v1.Secret) (*v1.Secret, error) { return secret, nil } + if existing, err := s.storage.Get(); err == nil && s.tls != nil { + if newSecret, updated, err := s.tls.Merge(secret, existing); err == nil && updated { + secret = newSecret + } + } + targetSecret, err := s.targetSecret() if err != nil { return nil, err