Merge pull request #3 from erikwilson/single-cert

Refactor to single cert
This commit is contained in:
Darren Shepherd 2019-09-26 13:58:02 -07:00 committed by GitHub
commit f3b73e948e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 167 additions and 328 deletions

View File

@ -26,11 +26,6 @@ func ReadTLSConfig(userConfig *UserConfig) error {
return err
}
userConfig.Mode = "https"
if len(userConfig.Domains) > 0 {
userConfig.Mode = "acme"
}
valid := false
if userConfig.Key != "" && userConfig.Cert != "" {
valid = true

490
server.go
View File

@ -1,37 +1,27 @@
package dynamiclistener
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/md5"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"log"
"net"
"net/http"
"sort"
"reflect"
"strconv"
"strings"
"sync"
"time"
lru "github.com/hashicorp/golang-lru"
cert "github.com/rancher/dynamiclistener/cert"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert"
)
const (
httpsMode = "https"
acmeMode = "acme"
)
type server struct {
@ -39,28 +29,28 @@ type server struct {
userConfig UserConfig
listenConfigStorage ListenerConfigStorage
certs map[string]*tls.Certificate
ips *lru.Cache
tlsCert *tls.Certificate
ips map[string]bool
domains map[string]bool
cn string
listeners []net.Listener
servers []*http.Server
// dynamic config change on refresh
activeCert *tls.Certificate
activeCA *x509.Certificate
activeCAKey crypto.Signer
activeCAKeyString string
domains map[string]bool
activeCA *x509.Certificate
activeCAKey crypto.Signer
}
func NewServer(listenConfigStorage ListenerConfigStorage, config UserConfig) (ServerInterface, error) {
s := &server{
userConfig: config,
listenConfigStorage: listenConfigStorage,
certs: map[string]*tls.Certificate{},
cn: "cattle",
}
s.ips, _ = lru.New(20)
s.ips = map[string]bool{}
s.domains = map[string]bool{}
if err := s.userConfigure(); err != nil {
return nil, err
@ -81,16 +71,7 @@ func (s *server) CACert() (string, error) {
if s.userConfig.CACerts != "" {
return s.userConfig.CACerts, nil
}
status, err := s.listenConfigStorage.Get()
if err != nil {
return "", err
}
if status.CACert == "" {
return "", fmt.Errorf("ca cert not found")
}
return status.CACert, nil
return "", fmt.Errorf("ca cert not found")
}
func marshalPrivateKey(privateKey crypto.Signer) (string, []byte, error) {
@ -127,78 +108,25 @@ func newPrivateKey() (crypto.Signer, error) {
return caKeyIFace.(crypto.Signer), nil
}
func (s *server) save() {
if s.activeCert != nil {
return
func (s *server) save() (_err error) {
defer func() {
if _err != nil {
logrus.Errorf("Saving cert error: %s", _err)
}
}()
certStr, err := certToString(s.tlsCert)
if err != nil {
return err
}
s.Lock()
defer s.Unlock()
changed := false
cfg, err := s.listenConfigStorage.Get()
if err != nil {
return
return err
}
cfg.GeneratedCerts = map[string]string{s.cn: certStr}
if cfg.GeneratedCerts == nil {
cfg.GeneratedCerts = map[string]string{}
}
if cfg.KnownIPs == nil {
cfg.KnownIPs = map[string]bool{}
}
for key, cert := range s.certs {
certStr, err := certToString(cert)
if err != nil {
continue
}
if cfg.GeneratedCerts[key] != certStr {
cfg.GeneratedCerts[key] = certStr
changed = true
}
}
for _, obj := range s.ips.Keys() {
ip, _ := obj.(string)
if !cfg.KnownIPs[ip] {
cfg.KnownIPs[ip] = true
changed = true
}
}
if cfg.CAKey == "" && s.activeCAKey != nil && s.activeCA != nil {
caCertBuffer := bytes.Buffer{}
if err := pem.Encode(&caCertBuffer, &pem.Block{
Type: cert.CertificateBlockType,
Bytes: s.activeCA.Raw,
}); err != nil {
return
}
caKeyBuffer := bytes.Buffer{}
keyType, keyBytes, err := marshalPrivateKey(s.activeCAKey)
if err != nil {
return
}
if err := pem.Encode(&caKeyBuffer, &pem.Block{
Type: keyType,
Bytes: keyBytes,
}); err != nil {
return
}
cfg.CACert = string(caCertBuffer.Bytes())
cfg.CAKey = string(caKeyBuffer.Bytes())
s.activeCAKeyString = cfg.CAKey
changed = true
}
if changed {
s.listenConfigStorage.Set(cfg)
}
_, err = s.listenConfigStorage.Set(cfg)
return err
}
func (s *server) userConfigure() error {
@ -206,39 +134,42 @@ func (s *server) userConfigure() error {
s.userConfig.HTTPSPort = 8443
}
if s.userConfig.Mode == "" {
if len(s.userConfig.Domains) > 0 {
s.userConfig.Mode = acmeMode
} else {
s.userConfig.Mode = httpsMode
}
}
s.domains = map[string]bool{}
for _, d := range s.userConfig.Domains {
s.domains[d] = true
}
if s.userConfig.Key != "" && s.userConfig.Cert != "" {
cert, err := tls.X509KeyPair([]byte(s.userConfig.Cert), []byte(s.userConfig.Key))
if err != nil {
return err
for _, ip := range s.userConfig.KnownIPs {
if netIP := net.ParseIP(ip); netIP != nil {
s.ips[ip] = true
}
s.activeCert = &cert
s.userConfig.Mode = httpsMode
return s.reload()
}
for _, ip := range s.userConfig.KnownIPs {
netIP := net.ParseIP(ip)
if netIP != nil {
s.ips.Add(ip, netIP)
if bindAddress := net.ParseIP(s.userConfig.BindAddress); bindAddress != nil {
s.ips[s.userConfig.BindAddress] = true
}
if s.activeCA == nil && s.activeCAKey == nil {
if s.userConfig.CACerts != "" && s.userConfig.CAKey != "" {
ca, err := cert.ParseCertsPEM([]byte(s.userConfig.CACerts))
if err != nil {
return err
}
key, err := cert.ParsePrivateKeyPEM([]byte(s.userConfig.CAKey))
if err != nil {
return err
}
s.activeCA = ca[0]
s.activeCAKey = key.(crypto.Signer)
} else {
ca, key, err := genCA()
if err != nil {
return err
}
s.activeCA = ca
s.activeCAKey = key
}
}
bindAddress := net.ParseIP(s.userConfig.BindAddress)
if bindAddress != nil {
s.ips.Add(s.userConfig.BindAddress, bindAddress)
}
return nil
}
@ -259,61 +190,76 @@ func genCA() (*x509.Certificate, crypto.Signer, error) {
return caCert, caKey, nil
}
func (s *server) Update(status *ListenerStatus) error {
func (s *server) Update(status *ListenerStatus) (_err error) {
s.Lock()
defer s.getCertificate(&tls.ClientHelloInfo{ServerName: "localhost"})
if status.CACert != "" && status.CAKey != "" && s.activeCAKeyString != status.CAKey {
cert, err := tls.X509KeyPair([]byte(status.CACert), []byte(status.CAKey))
if err != nil {
s.Unlock()
return err
defer func() {
s.Unlock()
if _err != nil {
logrus.Errorf("Update cert error: %s", _err)
}
s.activeCAKey = cert.PrivateKey.(crypto.Signer)
s.activeCAKeyString = status.CAKey
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
s.Unlock()
return err
if s.tlsCert == nil {
s.getCertificate(&tls.ClientHelloInfo{ServerName: "localhost"})
}
s.activeCA = x509Cert
s.certs = map[string]*tls.Certificate{}
}()
certString := status.GeneratedCerts[s.cn]
tlsCert, err := stringToCert(certString)
if err != nil {
logrus.Errorf("Update cert unable to convert string to cert: %s", err)
s.tlsCert = nil
}
if tlsCert != nil {
s.tlsCert = tlsCert
for i, certBytes := range tlsCert.Certificate {
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
logrus.Errorf("Update cert %d parse error: %s", i, err)
s.tlsCert = nil
break
}
for ipStr := range status.KnownIPs {
ip := net.ParseIP(ipStr)
if len(ip) > 0 {
s.ips.ContainsOrAdd(ipStr, ip)
ips := map[string]bool{}
for _, ip := range cert.IPAddresses {
ips[ip.String()] = true
}
domains := map[string]bool{}
for _, domain := range cert.DNSNames {
domains[domain] = true
}
if !(reflect.DeepEqual(ips, s.ips) && reflect.DeepEqual(domains, s.domains)) {
subset := true
for ip := range s.ips {
if !ips[ip] {
subset = false
break
}
}
if subset {
for domain := range s.domains {
if !domains[domain] {
subset = false
break
}
}
}
if !subset {
s.tlsCert = nil
}
for ip := range ips {
s.ips[ip] = true
}
for domain := range domains {
s.domains[domain] = true
}
}
}
}
for key, certString := range status.GeneratedCerts {
cert := stringToCert(certString)
if cert != nil {
s.certs[key] = cert
}
}
s.Unlock()
return s.reload()
}
func (s *server) hostPolicy(ctx context.Context, host string) error {
s.Lock()
defer s.Unlock()
if s.domains[host] {
return nil
}
return errors.New("acme/autocert: host not configured")
}
func (s *server) prompt(tos string) bool {
return true
}
func (s *server) shutdown() error {
for _, listener := range s.listeners {
if err := listener.Close(); err != nil {
@ -339,114 +285,53 @@ func (s *server) reload() error {
return err
}
switch s.userConfig.Mode {
case acmeMode:
if err := s.serveACME(); err != nil {
return err
}
case httpsMode:
if err := s.serveHTTPS(); err != nil {
return err
}
if err := s.serveHTTPS(); err != nil {
return err
}
return nil
}
func (s *server) ipMapKey() string {
len := s.ips.Len()
keys := s.ips.Keys()
if len == 0 {
return fmt.Sprintf("local/%d", len)
} else if len == 1 {
return fmt.Sprintf("local/%s", keys[0])
}
sort.Slice(keys, func(i, j int) bool {
l, _ := keys[i].(string)
r, _ := keys[j].(string)
return l < r
})
if len < 6 {
return fmt.Sprintf("local/%v", keys)
}
digest := md5.New()
for _, k := range keys {
s, _ := k.(string)
digest.Write([]byte(s))
}
return fmt.Sprintf("local/%v", hex.EncodeToString(digest.Sum(nil)))
}
func (s *server) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
func (s *server) getCertificate(hello *tls.ClientHelloInfo) (_servingCert *tls.Certificate, _err error) {
s.Lock()
if s.activeCert != nil {
s.Unlock()
return s.activeCert, nil
}
changed := false
defer func() {
defer s.Unlock()
if _err != nil {
logrus.Errorf("Get certificate error: %s", _err)
return
}
if changed {
s.save()
}
}()
defer s.Unlock()
mapKey := hello.ServerName
cn := hello.ServerName
dnsNames := []string{cn}
ipBased := false
var ips []net.IP
if cn == "" {
mapKey = s.ipMapKey()
ipBased = true
if hello.ServerName != "" && !s.domains[hello.ServerName] {
s.tlsCert = nil
s.domains[hello.ServerName] = true
}
serverNameCert, ok := s.certs[mapKey]
if ok {
return serverNameCert, nil
if s.tlsCert != nil {
return s.tlsCert, nil
}
if ipBased {
cn = "cattle"
for _, ipStr := range s.ips.Keys() {
ip := net.ParseIP(ipStr.(string))
if len(ip) > 0 {
ips = append(ips, ip)
}
ips := []net.IP{}
for ipStr := range s.ips {
if ip := net.ParseIP(ipStr); ip != nil {
ips = append(ips, ip)
}
}
changed = true
if s.activeCA == nil {
if s.userConfig.CACerts != "" && s.userConfig.CAKey != "" {
ca, err := cert.ParseCertsPEM([]byte(s.userConfig.CACerts))
if err != nil {
return nil, err
}
key, err := cert.ParsePrivateKeyPEM([]byte(s.userConfig.CAKey))
if err != nil {
return nil, err
}
s.activeCA = ca[0]
s.activeCAKey = key.(crypto.Signer)
} else {
ca, key, err := genCA()
if err != nil {
return nil, err
}
s.activeCA = ca
s.activeCAKey = key
}
dnsNames := []string{}
for domain := range s.domains {
dnsNames = append(dnsNames, domain)
}
cfg := cert.Config{
CommonName: cn,
CommonName: s.cn,
Organization: s.activeCA.Subject.Organization,
Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
AltNames: cert.AltNames{
@ -472,23 +357,31 @@ func (s *server) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, e
PrivateKey: key,
}
s.certs[mapKey] = tlsCert
changed = true
s.tlsCert = tlsCert
return tlsCert, nil
}
func (s *server) cacheIPHandler(handler http.Handler) http.Handler {
func (s *server) cacheHandler(handler http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
h, _, err := net.SplitHostPort(req.Host)
if err != nil {
h = req.Host
}
ip := net.ParseIP(h)
if len(ip) > 0 {
if ok, _ := s.ips.ContainsOrAdd(h, ip); ok {
go s.save()
s.Lock()
if ip := net.ParseIP(h); ip != nil {
if !s.ips[h] {
s.ips[h] = true
s.tlsCert = nil
}
} else {
if !s.domains[h] {
s.domains[h] = true
s.tlsCert = nil
}
}
s.Unlock()
handler.ServeHTTP(resp, req)
})
@ -508,7 +401,7 @@ func (s *server) serveHTTPS() error {
logger := logrus.StandardLogger()
server := &http.Server{
Handler: s.cacheIPHandler(s.Handler()),
Handler: s.cacheHandler(s.Handler()),
ErrorLog: log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags),
}
@ -522,7 +415,7 @@ func (s *server) serveHTTPS() error {
}
httpServer := &http.Server{
Handler: s.cacheIPHandler(httpRedirect(s.Handler())),
Handler: s.cacheHandler(httpRedirect(s.Handler())),
ErrorLog: log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags),
}
@ -598,97 +491,48 @@ func (s *server) newListener(ip string, port int, config *tls.Config) (net.Liste
return l, nil
}
func (s *server) serveACME() error {
manager := autocert.Manager{
Cache: autocert.DirCache("certs-cache"),
Prompt: s.prompt,
HostPolicy: s.hostPolicy,
}
conf := &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if hello.ServerName == "localhost" || hello.ServerName == "" {
newHello := *hello
newHello.ServerName = s.userConfig.Domains[0]
return manager.GetCertificate(&newHello)
}
return manager.GetCertificate(hello)
},
NextProtos: []string{"h2", "http/1.1"},
}
if s.userConfig.HTTPPort > 0 {
httpListener, err := s.newListener(s.userConfig.BindAddress, s.userConfig.HTTPPort, nil)
if err != nil {
return err
}
httpServer := &http.Server{
Handler: manager.HTTPHandler(nil),
ErrorLog: log.New(logrus.StandardLogger().Writer(), "", log.LstdFlags),
}
s.servers = append(s.servers, httpServer)
go func() {
if err := httpServer.Serve(httpListener); err != nil {
logrus.Errorf("http server returned err: %v", err)
}
}()
}
httpsListener, err := s.newListener(s.userConfig.BindAddress, s.userConfig.HTTPSPort, conf)
if err != nil {
return err
}
httpsServer := &http.Server{
Handler: s.Handler(),
ErrorLog: log.New(logrus.StandardLogger().Writer(), "", log.LstdFlags),
}
s.servers = append(s.servers, httpsServer)
go func() {
if err := httpsServer.Serve(httpsListener); err != nil {
logrus.Errorf("https server returned err: %v", err)
}
}()
return nil
}
func stringToCert(certString string) *tls.Certificate {
func stringToCert(certString string) (*tls.Certificate, error) {
parts := strings.Split(certString, "#")
if len(parts) != 2 {
return nil
return nil, errors.New("Unable to split cert into two parts")
}
certPart, keyPart := parts[0], parts[1]
keyBytes, err := base64.StdEncoding.DecodeString(keyPart)
if err != nil {
return nil
return nil, err
}
key, err := cert.ParsePrivateKeyPEM(keyBytes)
if err != nil {
return nil
return nil, err
}
certBytes, err := base64.StdEncoding.DecodeString(certPart)
if err != nil {
return nil
return nil, err
}
return &tls.Certificate{
Certificate: [][]byte{certBytes},
PrivateKey: key,
}
}, nil
}
func certToString(cert *tls.Certificate) (string, error) {
_, keyBytes, err := marshalPrivateKey(cert.PrivateKey.(crypto.Signer))
keyType, keyBytes, err := marshalPrivateKey(cert.PrivateKey.(crypto.Signer))
if err != nil {
return "", err
}
privateKeyPemBlock := &pem.Block{
Type: keyType,
Bytes: keyBytes,
}
pemBytes := pem.EncodeToMemory(privateKeyPemBlock)
certString := base64.StdEncoding.EncodeToString(cert.Certificate[0])
keyString := base64.StdEncoding.EncodeToString(keyBytes)
keyString := base64.StdEncoding.EncodeToString(pemBytes)
return certString + "#" + keyString, nil
}