From c08b499d17195fbc2c1764b21c322951811629a5 Mon Sep 17 00:00:00 2001 From: Erik Wilson Date: Wed, 17 Jul 2019 09:46:34 -0700 Subject: [PATCH] Refactor to single cert --- read.go | 5 - server.go | 490 +++++++++++++++++++----------------------------------- 2 files changed, 167 insertions(+), 328 deletions(-) diff --git a/read.go b/read.go index b078703..bf99514 100644 --- a/read.go +++ b/read.go @@ -26,11 +26,6 @@ func ReadTLSConfig(userConfig *UserConfig) error { return err } - userConfig.Mode = "https" - if len(userConfig.Domains) > 0 { - userConfig.Mode = "acme" - } - valid := false if userConfig.Key != "" && userConfig.Cert != "" { valid = true diff --git a/server.go b/server.go index c839da5..ca43086 100644 --- a/server.go +++ b/server.go @@ -1,37 +1,27 @@ package dynamiclistener import ( - "bytes" "context" "crypto" "crypto/ecdsa" - "crypto/md5" "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/base64" - "encoding/hex" "encoding/pem" "errors" "fmt" "log" "net" "net/http" - "sort" + "reflect" "strconv" "strings" "sync" "time" - lru "github.com/hashicorp/golang-lru" cert "github.com/rancher/dynamiclistener/cert" "github.com/sirupsen/logrus" - "golang.org/x/crypto/acme/autocert" -) - -const ( - httpsMode = "https" - acmeMode = "acme" ) type server struct { @@ -39,28 +29,28 @@ type server struct { userConfig UserConfig listenConfigStorage ListenerConfigStorage - certs map[string]*tls.Certificate - ips *lru.Cache + tlsCert *tls.Certificate + + ips map[string]bool + domains map[string]bool + cn string listeners []net.Listener servers []*http.Server - // dynamic config change on refresh - activeCert *tls.Certificate - activeCA *x509.Certificate - activeCAKey crypto.Signer - activeCAKeyString string - domains map[string]bool + activeCA *x509.Certificate + activeCAKey crypto.Signer } func NewServer(listenConfigStorage ListenerConfigStorage, config UserConfig) (ServerInterface, error) { s := &server{ userConfig: config, listenConfigStorage: listenConfigStorage, - certs: map[string]*tls.Certificate{}, + cn: "cattle", } - s.ips, _ = lru.New(20) + s.ips = map[string]bool{} + s.domains = map[string]bool{} if err := s.userConfigure(); err != nil { return nil, err @@ -81,16 +71,7 @@ func (s *server) CACert() (string, error) { if s.userConfig.CACerts != "" { return s.userConfig.CACerts, nil } - status, err := s.listenConfigStorage.Get() - if err != nil { - return "", err - } - - if status.CACert == "" { - return "", fmt.Errorf("ca cert not found") - } - - return status.CACert, nil + return "", fmt.Errorf("ca cert not found") } func marshalPrivateKey(privateKey crypto.Signer) (string, []byte, error) { @@ -127,78 +108,25 @@ func newPrivateKey() (crypto.Signer, error) { return caKeyIFace.(crypto.Signer), nil } -func (s *server) save() { - if s.activeCert != nil { - return +func (s *server) save() (_err error) { + defer func() { + if _err != nil { + logrus.Errorf("Saving cert error: %s", _err) + } + }() + + certStr, err := certToString(s.tlsCert) + if err != nil { + return err } - - s.Lock() - defer s.Unlock() - - changed := false cfg, err := s.listenConfigStorage.Get() if err != nil { - return + return err } + cfg.GeneratedCerts = map[string]string{s.cn: certStr} - if cfg.GeneratedCerts == nil { - cfg.GeneratedCerts = map[string]string{} - } - - if cfg.KnownIPs == nil { - cfg.KnownIPs = map[string]bool{} - } - - for key, cert := range s.certs { - certStr, err := certToString(cert) - if err != nil { - continue - } - if cfg.GeneratedCerts[key] != certStr { - cfg.GeneratedCerts[key] = certStr - changed = true - } - } - - for _, obj := range s.ips.Keys() { - ip, _ := obj.(string) - if !cfg.KnownIPs[ip] { - cfg.KnownIPs[ip] = true - changed = true - } - } - - if cfg.CAKey == "" && s.activeCAKey != nil && s.activeCA != nil { - caCertBuffer := bytes.Buffer{} - if err := pem.Encode(&caCertBuffer, &pem.Block{ - Type: cert.CertificateBlockType, - Bytes: s.activeCA.Raw, - }); err != nil { - return - } - - caKeyBuffer := bytes.Buffer{} - keyType, keyBytes, err := marshalPrivateKey(s.activeCAKey) - if err != nil { - return - } - - if err := pem.Encode(&caKeyBuffer, &pem.Block{ - Type: keyType, - Bytes: keyBytes, - }); err != nil { - return - } - - cfg.CACert = string(caCertBuffer.Bytes()) - cfg.CAKey = string(caKeyBuffer.Bytes()) - s.activeCAKeyString = cfg.CAKey - changed = true - } - - if changed { - s.listenConfigStorage.Set(cfg) - } + _, err = s.listenConfigStorage.Set(cfg) + return err } func (s *server) userConfigure() error { @@ -206,39 +134,42 @@ func (s *server) userConfigure() error { s.userConfig.HTTPSPort = 8443 } - if s.userConfig.Mode == "" { - if len(s.userConfig.Domains) > 0 { - s.userConfig.Mode = acmeMode - } else { - s.userConfig.Mode = httpsMode - } - } - - s.domains = map[string]bool{} for _, d := range s.userConfig.Domains { s.domains[d] = true } - if s.userConfig.Key != "" && s.userConfig.Cert != "" { - cert, err := tls.X509KeyPair([]byte(s.userConfig.Cert), []byte(s.userConfig.Key)) - if err != nil { - return err + for _, ip := range s.userConfig.KnownIPs { + if netIP := net.ParseIP(ip); netIP != nil { + s.ips[ip] = true } - s.activeCert = &cert - s.userConfig.Mode = httpsMode - return s.reload() } - for _, ip := range s.userConfig.KnownIPs { - netIP := net.ParseIP(ip) - if netIP != nil { - s.ips.Add(ip, netIP) + if bindAddress := net.ParseIP(s.userConfig.BindAddress); bindAddress != nil { + s.ips[s.userConfig.BindAddress] = true + } + + if s.activeCA == nil && s.activeCAKey == nil { + if s.userConfig.CACerts != "" && s.userConfig.CAKey != "" { + ca, err := cert.ParseCertsPEM([]byte(s.userConfig.CACerts)) + if err != nil { + return err + } + key, err := cert.ParsePrivateKeyPEM([]byte(s.userConfig.CAKey)) + if err != nil { + return err + } + s.activeCA = ca[0] + s.activeCAKey = key.(crypto.Signer) + } else { + ca, key, err := genCA() + if err != nil { + return err + } + s.activeCA = ca + s.activeCAKey = key } } - bindAddress := net.ParseIP(s.userConfig.BindAddress) - if bindAddress != nil { - s.ips.Add(s.userConfig.BindAddress, bindAddress) - } + return nil } @@ -259,61 +190,76 @@ func genCA() (*x509.Certificate, crypto.Signer, error) { return caCert, caKey, nil } -func (s *server) Update(status *ListenerStatus) error { +func (s *server) Update(status *ListenerStatus) (_err error) { s.Lock() - defer s.getCertificate(&tls.ClientHelloInfo{ServerName: "localhost"}) - - if status.CACert != "" && status.CAKey != "" && s.activeCAKeyString != status.CAKey { - cert, err := tls.X509KeyPair([]byte(status.CACert), []byte(status.CAKey)) - if err != nil { - s.Unlock() - return err + defer func() { + s.Unlock() + if _err != nil { + logrus.Errorf("Update cert error: %s", _err) } - s.activeCAKey = cert.PrivateKey.(crypto.Signer) - s.activeCAKeyString = status.CAKey - - x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - s.Unlock() - return err + if s.tlsCert == nil { + s.getCertificate(&tls.ClientHelloInfo{ServerName: "localhost"}) } - s.activeCA = x509Cert - s.certs = map[string]*tls.Certificate{} + }() + + certString := status.GeneratedCerts[s.cn] + tlsCert, err := stringToCert(certString) + if err != nil { + logrus.Errorf("Update cert unable to convert string to cert: %s", err) + s.tlsCert = nil } + if tlsCert != nil { + s.tlsCert = tlsCert + for i, certBytes := range tlsCert.Certificate { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + logrus.Errorf("Update cert %d parse error: %s", i, err) + s.tlsCert = nil + break + } - for ipStr := range status.KnownIPs { - ip := net.ParseIP(ipStr) - if len(ip) > 0 { - s.ips.ContainsOrAdd(ipStr, ip) + ips := map[string]bool{} + for _, ip := range cert.IPAddresses { + ips[ip.String()] = true + } + + domains := map[string]bool{} + for _, domain := range cert.DNSNames { + domains[domain] = true + } + + if !(reflect.DeepEqual(ips, s.ips) && reflect.DeepEqual(domains, s.domains)) { + subset := true + for ip := range s.ips { + if !ips[ip] { + subset = false + break + } + } + if subset { + for domain := range s.domains { + if !domains[domain] { + subset = false + break + } + } + } + if !subset { + s.tlsCert = nil + } + for ip := range ips { + s.ips[ip] = true + } + for domain := range domains { + s.domains[domain] = true + } + } } } - for key, certString := range status.GeneratedCerts { - cert := stringToCert(certString) - if cert != nil { - s.certs[key] = cert - } - } - - s.Unlock() return s.reload() } -func (s *server) hostPolicy(ctx context.Context, host string) error { - s.Lock() - defer s.Unlock() - - if s.domains[host] { - return nil - } - - return errors.New("acme/autocert: host not configured") -} - -func (s *server) prompt(tos string) bool { - return true -} - func (s *server) shutdown() error { for _, listener := range s.listeners { if err := listener.Close(); err != nil { @@ -339,114 +285,53 @@ func (s *server) reload() error { return err } - switch s.userConfig.Mode { - case acmeMode: - if err := s.serveACME(); err != nil { - return err - } - case httpsMode: - if err := s.serveHTTPS(); err != nil { - return err - } + if err := s.serveHTTPS(); err != nil { + return err } return nil } -func (s *server) ipMapKey() string { - len := s.ips.Len() - keys := s.ips.Keys() - if len == 0 { - return fmt.Sprintf("local/%d", len) - } else if len == 1 { - return fmt.Sprintf("local/%s", keys[0]) - } - - sort.Slice(keys, func(i, j int) bool { - l, _ := keys[i].(string) - r, _ := keys[j].(string) - return l < r - }) - if len < 6 { - return fmt.Sprintf("local/%v", keys) - } - - digest := md5.New() - for _, k := range keys { - s, _ := k.(string) - digest.Write([]byte(s)) - } - - return fmt.Sprintf("local/%v", hex.EncodeToString(digest.Sum(nil))) -} - -func (s *server) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (s *server) getCertificate(hello *tls.ClientHelloInfo) (_servingCert *tls.Certificate, _err error) { s.Lock() - if s.activeCert != nil { - s.Unlock() - return s.activeCert, nil - } - changed := false + defer func() { + defer s.Unlock() + + if _err != nil { + logrus.Errorf("Get certificate error: %s", _err) + return + } + if changed { s.save() } }() - defer s.Unlock() - mapKey := hello.ServerName - cn := hello.ServerName - dnsNames := []string{cn} - ipBased := false - var ips []net.IP - - if cn == "" { - mapKey = s.ipMapKey() - ipBased = true + if hello.ServerName != "" && !s.domains[hello.ServerName] { + s.tlsCert = nil + s.domains[hello.ServerName] = true } - serverNameCert, ok := s.certs[mapKey] - if ok { - return serverNameCert, nil + if s.tlsCert != nil { + return s.tlsCert, nil } - if ipBased { - cn = "cattle" - for _, ipStr := range s.ips.Keys() { - ip := net.ParseIP(ipStr.(string)) - if len(ip) > 0 { - ips = append(ips, ip) - } + ips := []net.IP{} + for ipStr := range s.ips { + if ip := net.ParseIP(ipStr); ip != nil { + ips = append(ips, ip) } } - changed = true - - if s.activeCA == nil { - if s.userConfig.CACerts != "" && s.userConfig.CAKey != "" { - ca, err := cert.ParseCertsPEM([]byte(s.userConfig.CACerts)) - if err != nil { - return nil, err - } - key, err := cert.ParsePrivateKeyPEM([]byte(s.userConfig.CAKey)) - if err != nil { - return nil, err - } - s.activeCA = ca[0] - s.activeCAKey = key.(crypto.Signer) - } else { - ca, key, err := genCA() - if err != nil { - return nil, err - } - s.activeCA = ca - s.activeCAKey = key - } + dnsNames := []string{} + for domain := range s.domains { + dnsNames = append(dnsNames, domain) } cfg := cert.Config{ - CommonName: cn, + CommonName: s.cn, Organization: s.activeCA.Subject.Organization, Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, AltNames: cert.AltNames{ @@ -472,23 +357,31 @@ func (s *server) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, e PrivateKey: key, } - s.certs[mapKey] = tlsCert + changed = true + s.tlsCert = tlsCert return tlsCert, nil } -func (s *server) cacheIPHandler(handler http.Handler) http.Handler { +func (s *server) cacheHandler(handler http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { h, _, err := net.SplitHostPort(req.Host) if err != nil { h = req.Host } - ip := net.ParseIP(h) - if len(ip) > 0 { - if ok, _ := s.ips.ContainsOrAdd(h, ip); ok { - go s.save() + s.Lock() + if ip := net.ParseIP(h); ip != nil { + if !s.ips[h] { + s.ips[h] = true + s.tlsCert = nil + } + } else { + if !s.domains[h] { + s.domains[h] = true + s.tlsCert = nil } } + s.Unlock() handler.ServeHTTP(resp, req) }) @@ -508,7 +401,7 @@ func (s *server) serveHTTPS() error { logger := logrus.StandardLogger() server := &http.Server{ - Handler: s.cacheIPHandler(s.Handler()), + Handler: s.cacheHandler(s.Handler()), ErrorLog: log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags), } @@ -522,7 +415,7 @@ func (s *server) serveHTTPS() error { } httpServer := &http.Server{ - Handler: s.cacheIPHandler(httpRedirect(s.Handler())), + Handler: s.cacheHandler(httpRedirect(s.Handler())), ErrorLog: log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags), } @@ -598,97 +491,48 @@ func (s *server) newListener(ip string, port int, config *tls.Config) (net.Liste return l, nil } -func (s *server) serveACME() error { - manager := autocert.Manager{ - Cache: autocert.DirCache("certs-cache"), - Prompt: s.prompt, - HostPolicy: s.hostPolicy, - } - conf := &tls.Config{ - GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - if hello.ServerName == "localhost" || hello.ServerName == "" { - newHello := *hello - newHello.ServerName = s.userConfig.Domains[0] - return manager.GetCertificate(&newHello) - } - return manager.GetCertificate(hello) - }, - NextProtos: []string{"h2", "http/1.1"}, - } - - if s.userConfig.HTTPPort > 0 { - httpListener, err := s.newListener(s.userConfig.BindAddress, s.userConfig.HTTPPort, nil) - if err != nil { - return err - } - - httpServer := &http.Server{ - Handler: manager.HTTPHandler(nil), - ErrorLog: log.New(logrus.StandardLogger().Writer(), "", log.LstdFlags), - } - s.servers = append(s.servers, httpServer) - go func() { - if err := httpServer.Serve(httpListener); err != nil { - logrus.Errorf("http server returned err: %v", err) - } - }() - - } - - httpsListener, err := s.newListener(s.userConfig.BindAddress, s.userConfig.HTTPSPort, conf) - if err != nil { - return err - } - - httpsServer := &http.Server{ - Handler: s.Handler(), - ErrorLog: log.New(logrus.StandardLogger().Writer(), "", log.LstdFlags), - } - s.servers = append(s.servers, httpsServer) - go func() { - if err := httpsServer.Serve(httpsListener); err != nil { - logrus.Errorf("https server returned err: %v", err) - } - }() - - return nil -} - -func stringToCert(certString string) *tls.Certificate { +func stringToCert(certString string) (*tls.Certificate, error) { parts := strings.Split(certString, "#") if len(parts) != 2 { - return nil + return nil, errors.New("Unable to split cert into two parts") } certPart, keyPart := parts[0], parts[1] keyBytes, err := base64.StdEncoding.DecodeString(keyPart) if err != nil { - return nil + return nil, err } key, err := cert.ParsePrivateKeyPEM(keyBytes) if err != nil { - return nil + return nil, err } certBytes, err := base64.StdEncoding.DecodeString(certPart) if err != nil { - return nil + return nil, err } return &tls.Certificate{ Certificate: [][]byte{certBytes}, PrivateKey: key, - } + }, nil } func certToString(cert *tls.Certificate) (string, error) { - _, keyBytes, err := marshalPrivateKey(cert.PrivateKey.(crypto.Signer)) + keyType, keyBytes, err := marshalPrivateKey(cert.PrivateKey.(crypto.Signer)) if err != nil { return "", err } + + privateKeyPemBlock := &pem.Block{ + Type: keyType, + Bytes: keyBytes, + } + pemBytes := pem.EncodeToMemory(privateKeyPemBlock) + certString := base64.StdEncoding.EncodeToString(cert.Certificate[0]) - keyString := base64.StdEncoding.EncodeToString(keyBytes) + keyString := base64.StdEncoding.EncodeToString(pemBytes) return certString + "#" + keyString, nil }