From af0486784331e646b726b091edf061d8b15661bf Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Sat, 26 Oct 2019 23:18:58 -0700 Subject: [PATCH] Refactor to not include a server by default --- factory/ca.go | 80 +++++ factory/cert_utils.go | 105 ++++++ factory/gen.go | 164 +++++++++ go.mod | 13 +- listener.go | 198 +++++++++++ read.go | 50 --- redirect.go | 46 +++ server.go | 556 ------------------------------- server/server.go | 80 +++++ storage/file/file.go | 42 +++ storage/kubernetes/controller.go | 109 ++++++ storage/memory/memory.go | 42 +++ tcp.go | 38 +++ types.go | 63 ---- 14 files changed, 906 insertions(+), 680 deletions(-) create mode 100644 factory/ca.go create mode 100644 factory/cert_utils.go create mode 100644 factory/gen.go create mode 100644 listener.go delete mode 100644 read.go create mode 100644 redirect.go delete mode 100644 server.go create mode 100644 server/server.go create mode 100644 storage/file/file.go create mode 100644 storage/kubernetes/controller.go create mode 100644 storage/memory/memory.go create mode 100644 tcp.go delete mode 100644 types.go diff --git a/factory/ca.go b/factory/ca.go new file mode 100644 index 0000000..a35738c --- /dev/null +++ b/factory/ca.go @@ -0,0 +1,80 @@ +package factory + +import ( + "crypto/ecdsa" + "crypto/x509" + "io/ioutil" + "os" +) + +func GenCA() (*x509.Certificate, *ecdsa.PrivateKey, error) { + caKey, err := NewPrivateKey() + if err != nil { + return nil, nil, err + } + + caCert, err := NewSelfSignedCACert(caKey, "dynamiclistener-ca", "dynamiclistener-org") + if err != nil { + return nil, nil, err + } + + return caCert, caKey, nil +} + +func LoadOrGenCA() (*x509.Certificate, *ecdsa.PrivateKey, error) { + cert, key, err := loadCA() + if err == nil { + return cert, key, nil + } + + cert, key, err = GenCA() + if err != nil { + return nil, nil, err + } + + certBytes, keyBytes, err := Marshal(cert, key) + if err != nil { + return nil, nil, err + } + + if err := os.MkdirAll("./certs", 0700); err != nil { + return nil, nil, err + } + + if err := ioutil.WriteFile("./certs/ca.pem", certBytes, 0600); err != nil { + return nil, nil, err + } + + if err := ioutil.WriteFile("./certs/ca.key", keyBytes, 0600); err != nil { + return nil, nil, err + } + + return cert, key, nil +} + +func loadCA() (*x509.Certificate, *ecdsa.PrivateKey, error) { + return LoadCerts("./certs/ca.pem", "./certs/ca.key") +} + +func LoadCerts(certFile, keyFile string) (*x509.Certificate, *ecdsa.PrivateKey, error) { + caPem, err := ioutil.ReadFile(certFile) + if err != nil { + return nil, nil, err + } + caKey, err := ioutil.ReadFile(keyFile) + if err != nil { + return nil, nil, err + } + + key, err := ParseECPrivateKeyPEM(caKey) + if err != nil { + return nil, nil, err + } + + cert, err := ParseCertPEM(caPem) + if err != nil { + return nil, nil, err + } + + return cert, key, nil +} diff --git a/factory/cert_utils.go b/factory/cert_utils.go new file mode 100644 index 0000000..459bd2e --- /dev/null +++ b/factory/cert_utils.go @@ -0,0 +1,105 @@ +package factory + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math" + "math/big" + "net" + "time" +) + +const ( + ECPrivateKeyBlockType = "EC PRIVATE KEY" + CertificateBlockType = "CERTIFICATE" +) + +func NewSelfSignedCACert(key crypto.Signer, cn string, org ...string) (*x509.Certificate, error) { + now := time.Now() + tmpl := x509.Certificate{ + BasicConstraintsValid: true, + IsCA: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + NotAfter: now.Add(time.Hour * 24 * 365 * 10).UTC(), + NotBefore: now.UTC(), + SerialNumber: new(big.Int).SetInt64(0), + Subject: pkix.Name{ + CommonName: cn, + Organization: org, + }, + } + + certDERBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, key.Public(), key) + if err != nil { + return nil, err + } + + return x509.ParseCertificate(certDERBytes) +} + +func NewSignedCert(signer crypto.Signer, caCert *x509.Certificate, caKey crypto.Signer, cn string, orgs []string, + domains []string, ips []net.IP) (*x509.Certificate, error) { + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).SetInt64(math.MaxInt64)) + if err != nil { + return nil, err + } + + parent := x509.Certificate{ + DNSNames: domains, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: ips, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + NotAfter: time.Now().Add(time.Hour * 24 * 365).UTC(), + NotBefore: caCert.NotBefore, + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: cn, + Organization: orgs, + }, + } + + cert, err := x509.CreateCertificate(rand.Reader, &parent, caCert, signer.Public(), caKey) + if err != nil { + return nil, err + } + + return x509.ParseCertificate(cert) +} + +func ParseECPrivateKeyPEM(keyData []byte) (*ecdsa.PrivateKey, error) { + var privateKeyPemBlock *pem.Block + for { + privateKeyPemBlock, keyData = pem.Decode(keyData) + if privateKeyPemBlock == nil { + break + } + + if privateKeyPemBlock.Type == ECPrivateKeyBlockType { + return x509.ParseECPrivateKey(privateKeyPemBlock.Bytes) + } + } + + return nil, fmt.Errorf("pem does not include a valid EC private key") +} + +func ParseCertPEM(pemCerts []byte) (*x509.Certificate, error) { + var pemBlock *pem.Block + for { + pemBlock, pemCerts = pem.Decode(pemCerts) + if pemBlock == nil { + break + } + + if pemBlock.Type == CertificateBlockType { + return x509.ParseCertificate(pemBlock.Bytes) + } + } + + return nil, fmt.Errorf("pem does not include a valid x509 cert") +} diff --git a/factory/gen.go b/factory/gen.go new file mode 100644 index 0000000..2341748 --- /dev/null +++ b/factory/gen.go @@ -0,0 +1,164 @@ +package factory + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "encoding/hex" + "encoding/pem" + "net" + "sort" + "strings" + + v1 "k8s.io/api/core/v1" +) + +const ( + cnPrefix = "listener.cattle.io/cn-" + static = "listener.cattle.io/static" + hashKey = "listener.cattle.io/hash" +) + +type TLS struct { + CACert *x509.Certificate + CAKey crypto.Signer + CN string + Organization []string +} + +func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string, err error) { + var ( + cns []string + digest = sha256.New() + ) + for k, v := range secret.Annotations { + if strings.HasPrefix(k, cnPrefix) { + cns = append(cns, v) + } + } + + sort.Strings(cns) + + for _, v := range cns { + digest.Write([]byte(v)) + ip := net.ParseIP(v) + if ip == nil { + domains = append(domains, v) + } else { + ips = append(ips, ip) + } + } + + hash = hex.EncodeToString(digest.Sum(nil)) + return +} + +func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) { + var ( + err error + ) + + if !NeedsUpdate(secret, cn...) { + return secret, false, nil + } + + secret = populateCN(secret, cn...) + + privateKey, err := getPrivateKey(secret) + if err != nil { + return nil, false, err + } + + domains, ips, hash, err := collectCNs(secret) + if err != nil { + return nil, false, err + } + + newCert, err := t.newCert(domains, ips, privateKey) + if err != nil { + return nil, false, err + } + + certBytes, keyBytes, err := Marshal(newCert, privateKey) + if err != nil { + return nil, false, err + } + + if secret.Data == nil { + secret.Data = map[string][]byte{} + } + secret.Data[v1.TLSCertKey] = certBytes + secret.Data[v1.TLSPrivateKeyKey] = keyBytes + secret.Annotations[hashKey] = hash + + return secret, true, nil +} + +func (t *TLS) newCert(domains []string, ips []net.IP, privateKey *ecdsa.PrivateKey) (*x509.Certificate, error) { + return NewSignedCert(privateKey, t.CACert, t.CAKey, t.CN, t.Organization, domains, ips) +} + +func populateCN(secret *v1.Secret, cn ...string) *v1.Secret { + secret = secret.DeepCopy() + if secret.Annotations == nil { + secret.Annotations = map[string]string{} + } + for _, cn := range cn { + secret.Annotations[cnPrefix+cn] = cn + } + return secret +} + +func NeedsUpdate(secret *v1.Secret, cn ...string) bool { + if secret.Annotations[static] == "true" { + return false + } + + for _, cn := range cn { + if secret.Annotations[cnPrefix+cn] == "" { + return true + } + } + + return false +} + +func getPrivateKey(secret *v1.Secret) (*ecdsa.PrivateKey, error) { + keyBytes := secret.Data[v1.TLSPrivateKeyKey] + if len(keyBytes) == 0 { + return NewPrivateKey() + } + + privateKey, err := ParseECPrivateKeyPEM(keyBytes) + if err == nil { + return privateKey, nil + } + + return NewPrivateKey() +} + +func Marshal(x509Cert *x509.Certificate, privateKey *ecdsa.PrivateKey) ([]byte, []byte, error) { + certBlock := pem.Block{ + Type: CertificateBlockType, + Bytes: x509Cert.Raw, + } + + keyBytes, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return nil, nil, err + } + + keyBlock := pem.Block{ + Type: ECPrivateKeyBlockType, + Bytes: keyBytes, + } + + return pem.EncodeToMemory(&certBlock), pem.EncodeToMemory(&keyBlock), nil +} + +func NewPrivateKey() (*ecdsa.PrivateKey, error) { + return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) +} diff --git a/go.mod b/go.mod index c1ca544..a2a54ab 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,7 @@ module github.com/rancher/dynamiclistener go 1.12 require ( - github.com/hashicorp/golang-lru v0.5.1 - github.com/kisielk/gotool v1.0.0 // indirect - github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect + github.com/rancher/wrangler-api v0.2.0 github.com/sirupsen/logrus v1.4.1 - github.com/stretchr/testify v1.3.0 // indirect - github.com/stripe/safesql v0.2.0 // indirect - golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284 - golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c // indirect - golang.org/x/sys v0.0.0-20190509141414-a5b02f93d862 // indirect - golang.org/x/text v0.3.2 // indirect - mvdan.cc/interfacer v0.0.0-20180901003855-c20040233aed // indirect - mvdan.cc/lint v0.0.0-20170908181259-adc824a0674b // indirect + k8s.io/api v0.0.0-20190409021203-6e4e0e4f393b ) diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..7ba1a69 --- /dev/null +++ b/listener.go @@ -0,0 +1,198 @@ +package dynamiclistener + +import ( + "crypto" + "crypto/tls" + "crypto/x509" + "net" + "net/http" + "sync" + + "github.com/rancher/dynamiclistener/factory" + "github.com/sirupsen/logrus" + v1 "k8s.io/api/core/v1" +) + +type TLSStorage interface { + Get() (*v1.Secret, error) + Update(secret *v1.Secret) error +} + +type Config struct { + CN string + Organization []string + TLSConfig tls.Config + SANs []string +} + +func NewListener(l net.Listener, storage TLSStorage, caCert *x509.Certificate, caKey crypto.Signer, config Config) (net.Listener, http.Handler, error) { + if config.CN == "" { + config.CN = "dynamic" + } + if len(config.Organization) == 0 { + config.Organization = []string{"dynamic"} + } + + dynamicListener := &listener{ + factory: &factory.TLS{ + CACert: caCert, + CAKey: caKey, + CN: config.CN, + Organization: config.Organization, + }, + Listener: l, + storage: &nonNil{storage: storage}, + sans: config.SANs, + tlsConfig: config.TLSConfig, + } + dynamicListener.tlsConfig.GetCertificate = dynamicListener.getCertificate + + return tls.NewListener(dynamicListener, &dynamicListener.tlsConfig), dynamicListener.cacheHandler(), nil +} + +type listener struct { + sync.RWMutex + net.Listener + + factory *factory.TLS + storage TLSStorage + version string + tlsConfig tls.Config + cert *tls.Certificate + sans []string +} + +func (l *listener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return conn, err + } + + addr := conn.RemoteAddr() + if addr == nil { + return conn, nil + } + + host, _, err := net.SplitHostPort(addr.String()) + if err != nil { + logrus.Errorf("failed to parse network %s: %v", addr.Network(), err) + return conn, nil + } + + return conn, l.updateCert(host) +} + +func (l *listener) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if hello.ServerName != "" { + if err := l.updateCert(hello.ServerName); err != nil { + return nil, err + } + } + + return l.loadCert() +} + +func (l *listener) updateCert(cn string) error { + l.RLock() + defer l.RUnlock() + + secret, err := l.storage.Get() + if err != nil { + return err + } + + if !factory.NeedsUpdate(secret, append(l.sans, cn)...) { + return nil + } + + l.RUnlock() + l.Lock() + defer l.RLock() + defer l.Unlock() + + secret, updated, err := l.factory.AddCN(secret, append(l.sans, cn)...) + if err != nil { + return err + } + + if updated { + if err := l.storage.Update(secret); err != nil { + return err + } + // clear version to force cert reload + l.version = "" + } + + return nil +} + +func (l *listener) loadCert() (*tls.Certificate, error) { + l.RLock() + defer l.RUnlock() + + secret, err := l.storage.Get() + if err != nil { + return nil, err + } + if l.cert != nil && l.version == secret.ResourceVersion { + return l.cert, nil + } + + defer l.RLock() + l.RUnlock() + l.Lock() + defer l.Unlock() + + secret, err = l.storage.Get() + if err != nil { + return nil, err + } + if l.cert != nil && l.version == secret.ResourceVersion { + return l.cert, nil + } + + cert, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey]) + if err != nil { + return nil, err + } + + l.cert = &cert + return l.cert, nil +} + +func (l *listener) cacheHandler() http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + h, _, err := net.SplitHostPort(req.Host) + if err != nil { + h = req.Host + } + + ip := net.ParseIP(h) + if len(ip) > 0 { + l.updateCert(h) + } + }) +} + +type nonNil struct { + sync.Mutex + storage TLSStorage +} + +func (n *nonNil) Get() (*v1.Secret, error) { + n.Lock() + defer n.Unlock() + + s, err := n.storage.Get() + if err != nil || s == nil { + return &v1.Secret{}, err + } + return s, nil +} + +func (n *nonNil) Update(secret *v1.Secret) error { + n.Lock() + defer n.Unlock() + + return n.storage.Update(secret) +} diff --git a/read.go b/read.go deleted file mode 100644 index bf99514..0000000 --- a/read.go +++ /dev/null @@ -1,50 +0,0 @@ -package dynamiclistener - -import ( - "fmt" - "io/ioutil" - "path/filepath" -) - -func ReadTLSConfig(userConfig *UserConfig) error { - var err error - - path := userConfig.CertPath - - userConfig.CACerts, err = readPEM(filepath.Join(path, "cacerts.pem")) - if err != nil { - return err - } - - userConfig.Key, err = readPEM(filepath.Join(path, "key.pem")) - if err != nil { - return err - } - - userConfig.Cert, err = readPEM(filepath.Join(path, "cert.pem")) - if err != nil { - return err - } - - valid := false - if userConfig.Key != "" && userConfig.Cert != "" { - valid = true - } else if userConfig.Key == "" && userConfig.Cert == "" { - valid = true - } - - if !valid { - return fmt.Errorf("invalid SSL configuration found, please set cert/key, cert/key/cacerts, cacerts only, or none") - } - - return nil -} - -func readPEM(path string) (string, error) { - content, err := ioutil.ReadFile(path) - if err != nil { - return "", nil - } - - return string(content), nil -} diff --git a/redirect.go b/redirect.go new file mode 100644 index 0000000..264efde --- /dev/null +++ b/redirect.go @@ -0,0 +1,46 @@ +package dynamiclistener + +import ( + "fmt" + "net" + "net/http" + "strconv" + "strings" +) + +// Approach taken from letsencrypt, except manglePort is specific to us +func HTTPRedirect(next http.Handler) http.Handler { + return http.HandlerFunc( + func(rw http.ResponseWriter, r *http.Request) { + fmt.Println("!!!!!", r.URL.String(), r.Header) + if r.Header.Get("x-Forwarded-Proto") == "https" || + r.Header.Get("x-Forwarded-Proto") == "wss" || + strings.HasPrefix(r.URL.Path, "/ping") || + strings.HasPrefix(r.URL.Path, "/health") { + next.ServeHTTP(rw, r) + return + } + if r.Method != "GET" && r.Method != "HEAD" { + http.Error(rw, "Use HTTPS", http.StatusBadRequest) + return + } + target := "https://" + manglePort(r.Host) + r.URL.RequestURI() + http.Redirect(rw, r, target, http.StatusFound) + }) +} + +func manglePort(hostport string) string { + host, port, err := net.SplitHostPort(hostport) + if err != nil { + return hostport + } + + portInt, err := strconv.Atoi(port) + if err != nil { + return hostport + } + + portInt = ((portInt / 1000) * 1000) + 443 + + return net.JoinHostPort(host, strconv.Itoa(portInt)) +} diff --git a/server.go b/server.go deleted file mode 100644 index 2c01dbf..0000000 --- a/server.go +++ /dev/null @@ -1,556 +0,0 @@ -package dynamiclistener - -import ( - "context" - "crypto" - "crypto/ecdsa" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "encoding/pem" - "errors" - "fmt" - "log" - "net" - "net/http" - "reflect" - "strconv" - "strings" - "sync" - "time" - - cert "github.com/rancher/dynamiclistener/cert" - "github.com/sirupsen/logrus" -) - -type server struct { - sync.Mutex - - userConfig UserConfig - listenConfigStorage ListenerConfigStorage - tlsCert *tls.Certificate - - ips map[string]bool - domains map[string]bool - cn string - - listeners []net.Listener - servers []*http.Server - - activeCA *x509.Certificate - activeCAKey crypto.Signer -} - -func NewServer(listenConfigStorage ListenerConfigStorage, config UserConfig) (ServerInterface, error) { - s := &server{ - userConfig: config, - listenConfigStorage: listenConfigStorage, - cn: "cattle", - } - - s.ips = map[string]bool{} - s.domains = map[string]bool{} - - if err := s.userConfigure(); err != nil { - return nil, err - } - - lc, err := listenConfigStorage.Get() - if err != nil { - return nil, err - } - - return s, s.Update(lc) -} - -func (s *server) CACert() (string, error) { - if s.userConfig.NoCACerts { - return "", nil - } - if s.userConfig.CACerts != "" { - return s.userConfig.CACerts, nil - } - return "", fmt.Errorf("ca cert not found") -} - -func marshalPrivateKey(privateKey crypto.Signer) (string, []byte, error) { - var ( - keyType string - bytes []byte - err error - ) - if key, ok := privateKey.(*ecdsa.PrivateKey); ok { - keyType = cert.ECPrivateKeyBlockType - bytes, err = x509.MarshalECPrivateKey(key) - } else if key, ok := privateKey.(*rsa.PrivateKey); ok { - keyType = cert.RSAPrivateKeyBlockType - bytes = x509.MarshalPKCS1PrivateKey(key) - } else { - keyType = cert.PrivateKeyBlockType - bytes, err = x509.MarshalPKCS8PrivateKey(privateKey) - } - if err != nil { - logrus.Errorf("Unable to marshal private key: %v", err) - } - return keyType, bytes, err -} - -func newPrivateKey() (crypto.Signer, error) { - caKeyBytes, err := cert.MakeEllipticPrivateKeyPEM() - if err != nil { - return nil, err - } - caKeyIFace, err := cert.ParsePrivateKeyPEM(caKeyBytes) - if err != nil { - return nil, err - } - return caKeyIFace.(crypto.Signer), nil -} - -func (s *server) save() (_err error) { - defer func() { - if _err != nil { - logrus.Errorf("Saving cert error: %s", _err) - } - }() - - certStr, err := certToString(s.tlsCert) - if err != nil { - return err - } - cfg, err := s.listenConfigStorage.Get() - if err != nil { - return err - } - cfg.GeneratedCerts = map[string]string{s.cn: certStr} - - _, err = s.listenConfigStorage.Set(cfg) - return err -} - -func (s *server) userConfigure() error { - if s.userConfig.HTTPSPort == 0 { - s.userConfig.HTTPSPort = 8443 - } - - for _, d := range s.userConfig.Domains { - s.domains[d] = true - } - - for _, ip := range s.userConfig.KnownIPs { - if netIP := net.ParseIP(ip); netIP != nil { - s.ips[ip] = true - } - } - - if bindAddress := net.ParseIP(s.userConfig.BindAddress); bindAddress != nil { - s.ips[s.userConfig.BindAddress] = true - } - - if s.activeCA == nil && s.activeCAKey == nil { - if s.userConfig.CACerts != "" && s.userConfig.CAKey != "" { - ca, err := cert.ParseCertsPEM([]byte(s.userConfig.CACerts)) - if err != nil { - return err - } - key, err := cert.ParsePrivateKeyPEM([]byte(s.userConfig.CAKey)) - if err != nil { - return err - } - s.activeCA = ca[0] - s.activeCAKey = key.(crypto.Signer) - } else { - ca, key, err := genCA() - if err != nil { - return err - } - s.activeCA = ca - s.activeCAKey = key - } - } - - return nil -} - -func genCA() (*x509.Certificate, crypto.Signer, error) { - caKey, err := newPrivateKey() - if err != nil { - return nil, nil, err - } - - caCert, err := cert.NewSelfSignedCACert(cert.Config{ - CommonName: "k3s-ca", - Organization: []string{"k3s-org"}, - }, caKey) - if err != nil { - return nil, nil, err - } - - return caCert, caKey, nil -} - -func (s *server) Update(status *ListenerStatus) (_err error) { - s.Lock() - defer func() { - s.Unlock() - if _err != nil { - logrus.Errorf("Update cert error: %s", _err) - } - if s.tlsCert == nil { - s.getCertificate(&tls.ClientHelloInfo{ServerName: "localhost"}) - } - }() - - certString := status.GeneratedCerts[s.cn] - tlsCert, err := stringToCert(certString) - if err != nil { - logrus.Errorf("Update cert unable to convert string to cert: %s", err) - s.tlsCert = nil - } - if tlsCert != nil { - s.tlsCert = tlsCert - for i, certBytes := range tlsCert.Certificate { - parsedCert, err := x509.ParseCertificate(certBytes) - if err != nil { - logrus.Errorf("Update cert %d parse error: %s", i, err) - s.tlsCert = nil - break - } - isExpired := cert.IsCertExpired(parsedCert) - if isExpired { - logrus.Infof("certificate is about to expire") - s.tlsCert = nil - break - } - ips := map[string]bool{} - for _, ip := range parsedCert.IPAddresses { - ips[ip.String()] = true - } - - domains := map[string]bool{} - for _, domain := range parsedCert.DNSNames { - domains[domain] = true - } - - if !(reflect.DeepEqual(ips, s.ips) && reflect.DeepEqual(domains, s.domains)) { - subset := true - for ip := range s.ips { - if !ips[ip] { - subset = false - break - } - } - if subset { - for domain := range s.domains { - if !domains[domain] { - subset = false - break - } - } - } - if !subset { - s.tlsCert = nil - } - for ip := range ips { - s.ips[ip] = true - } - for domain := range domains { - s.domains[domain] = true - } - } - } - } - - return s.reload() -} - -func (s *server) shutdown() error { - for _, listener := range s.listeners { - if err := listener.Close(); err != nil { - return err - } - } - s.listeners = nil - - for _, server := range s.servers { - go server.Shutdown(context.Background()) - } - s.servers = nil - - return nil -} - -func (s *server) reload() error { - if len(s.listeners) > 0 { - return nil - } - - if err := s.shutdown(); err != nil { - return err - } - - if err := s.serveHTTPS(); err != nil { - return err - } - - return nil -} - -func (s *server) getCertificate(hello *tls.ClientHelloInfo) (_servingCert *tls.Certificate, _err error) { - s.Lock() - changed := false - - defer func() { - defer s.Unlock() - - if _err != nil { - logrus.Errorf("Get certificate error: %s", _err) - return - } - - if changed { - s.save() - } - }() - - if hello.ServerName != "" && !s.domains[hello.ServerName] { - s.tlsCert = nil - s.domains[hello.ServerName] = true - } - - if s.tlsCert != nil { - return s.tlsCert, nil - } - - ips := []net.IP{} - for ipStr := range s.ips { - if ip := net.ParseIP(ipStr); ip != nil { - ips = append(ips, ip) - } - } - - dnsNames := []string{} - for domain := range s.domains { - dnsNames = append(dnsNames, domain) - } - - cfg := cert.Config{ - CommonName: s.cn, - Organization: s.activeCA.Subject.Organization, - Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - AltNames: cert.AltNames{ - DNSNames: dnsNames, - IPs: ips, - }, - } - - key, err := newPrivateKey() - if err != nil { - return nil, err - } - - cert, err := cert.NewSignedCert(cfg, key, s.activeCA, s.activeCAKey) - if err != nil { - return nil, err - } - - tlsCert := &tls.Certificate{ - Certificate: [][]byte{ - cert.Raw, - }, - PrivateKey: key, - } - - changed = true - s.tlsCert = tlsCert - return tlsCert, nil -} - -func (s *server) cacheHandler(handler http.Handler) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - h, _, err := net.SplitHostPort(req.Host) - if err != nil { - h = req.Host - } - - s.Lock() - if ip := net.ParseIP(h); ip != nil { - if !s.ips[h] { - s.ips[h] = true - s.tlsCert = nil - } - } else { - if !s.domains[h] { - s.domains[h] = true - s.tlsCert = nil - } - } - s.Unlock() - - handler.ServeHTTP(resp, req) - }) -} - -func (s *server) serveHTTPS() error { - conf := &tls.Config{ - ClientAuth: tls.RequestClientCert, - GetCertificate: s.getCertificate, - PreferServerCipherSuites: true, - } - - listener, err := s.newListener(s.userConfig.BindAddress, s.userConfig.HTTPSPort, conf) - if err != nil { - return err - } - - logger := logrus.StandardLogger() - server := &http.Server{ - Handler: s.cacheHandler(s.Handler()), - ErrorLog: log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags), - } - - s.servers = append(s.servers, server) - s.startServer(listener, server) - - if s.userConfig.HTTPPort > 0 { - httpListener, err := s.newListener(s.userConfig.BindAddress, s.userConfig.HTTPPort, nil) - if err != nil { - return err - } - - httpServer := &http.Server{ - Handler: s.cacheHandler(httpRedirect(s.Handler())), - ErrorLog: log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags), - } - - s.servers = append(s.servers, httpServer) - s.startServer(httpListener, httpServer) - } - - return nil -} - -// Approach taken from letsencrypt, except manglePort is specific to us -func httpRedirect(next http.Handler) http.Handler { - return http.HandlerFunc( - func(rw http.ResponseWriter, r *http.Request) { - if r.Header.Get("x-Forwarded-Proto") == "https" || - strings.HasPrefix(r.URL.Path, "/ping") || - strings.HasPrefix(r.URL.Path, "/health") { - next.ServeHTTP(rw, r) - return - } - if r.Method != "GET" && r.Method != "HEAD" { - http.Error(rw, "Use HTTPS", http.StatusBadRequest) - return - } - target := "https://" + manglePort(r.Host) + r.URL.RequestURI() - http.Redirect(rw, r, target, http.StatusFound) - }) -} - -func manglePort(hostport string) string { - host, port, err := net.SplitHostPort(hostport) - if err != nil { - return hostport - } - - portInt, err := strconv.Atoi(port) - if err != nil { - return hostport - } - - portInt = ((portInt / 1000) * 1000) + 443 - - return net.JoinHostPort(host, strconv.Itoa(portInt)) -} - -func (s *server) startServer(listener net.Listener, server *http.Server) { - go func() { - if err := server.Serve(listener); err != nil { - logrus.Errorf("server on %v returned err: %v", listener.Addr(), err) - } - }() -} - -func (s *server) Handler() http.Handler { - return s.userConfig.Handler -} - -func (s *server) newListener(ip string, port int, config *tls.Config) (net.Listener, error) { - addr := fmt.Sprintf("%s:%d", ip, port) - l, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - - l = tcpKeepAliveListener{l.(*net.TCPListener)} - - if config != nil { - l = tls.NewListener(l, config) - } - - s.listeners = append(s.listeners, l) - logrus.Info("Listening on ", addr) - return l, nil -} - -func stringToCert(certString string) (*tls.Certificate, error) { - parts := strings.Split(certString, "#") - if len(parts) != 2 { - return nil, errors.New("Unable to split cert into two parts") - } - - certPart, keyPart := parts[0], parts[1] - keyBytes, err := base64.StdEncoding.DecodeString(keyPart) - if err != nil { - return nil, err - } - - key, err := cert.ParsePrivateKeyPEM(keyBytes) - if err != nil { - return nil, err - } - - certBytes, err := base64.StdEncoding.DecodeString(certPart) - if err != nil { - return nil, err - } - - return &tls.Certificate{ - Certificate: [][]byte{certBytes}, - PrivateKey: key, - }, nil -} - -func certToString(cert *tls.Certificate) (string, error) { - keyType, keyBytes, err := marshalPrivateKey(cert.PrivateKey.(crypto.Signer)) - if err != nil { - return "", err - } - - privateKeyPemBlock := &pem.Block{ - Type: keyType, - Bytes: keyBytes, - } - pemBytes := pem.EncodeToMemory(privateKeyPemBlock) - - certString := base64.StdEncoding.EncodeToString(cert.Certificate[0]) - keyString := base64.StdEncoding.EncodeToString(pemBytes) - return certString + "#" + keyString, nil -} - -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - tc, err := ln.AcceptTCP() - if err != nil { - return - } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) - return tc, nil -} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..cf069df --- /dev/null +++ b/server/server.go @@ -0,0 +1,80 @@ +package server + +import ( + "context" + "fmt" + "net/http" + + "github.com/rancher/dynamiclistener" + "github.com/rancher/dynamiclistener/factory" + "github.com/rancher/dynamiclistener/storage/memory" + "github.com/sirupsen/logrus" +) + +func ListenAndServe(ctx context.Context, httpsPort, httpPort int, handler http.Handler) error { + var ( + // https listener will change this if http is enabled + targetHandler = handler + ) + if httpsPort > 0 { + caCert, caKey, err := factory.LoadOrGenCA() + if err != nil { + return err + } + + tlsTCPListener, err := dynamiclistener.NewTCPListener("0.0.0.0", httpsPort) + if err != nil { + return err + } + + dynListener, dynHandler, err := dynamiclistener.NewListener(tlsTCPListener, memory.New(), caCert, caKey, dynamiclistener.Config{}) + if err != nil { + return err + } + + targetHandler = wrapHandler(dynHandler, handler) + tlsServer := http.Server{ + Handler: targetHandler, + } + targetHandler = dynamiclistener.HTTPRedirect(targetHandler) + + go func() { + logrus.Infof("Listening on 0.0.0.0:%d", httpsPort) + err := tlsServer.Serve(dynListener) + if err != http.ErrServerClosed && err != nil { + logrus.Fatalf("https server failed: %v", err) + } + }() + go func() { + <-ctx.Done() + tlsServer.Shutdown(context.Background()) + }() + } + + if httpPort > 0 { + httpServer := http.Server{ + Addr: fmt.Sprintf("0.0.0.0:%d", httpPort), + Handler: targetHandler, + } + go func() { + logrus.Infof("Listening on 0.0.0.0:%d", httpPort) + err := httpServer.ListenAndServe() + if err != http.ErrServerClosed && err != nil { + logrus.Fatalf("http server failed: %v", err) + } + }() + go func() { + <-ctx.Done() + httpServer.Shutdown(context.Background()) + }() + } + + return nil +} + +func wrapHandler(handler http.Handler, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + handler.ServeHTTP(rw, req) + next.ServeHTTP(rw, req) + }) +} diff --git a/storage/file/file.go b/storage/file/file.go new file mode 100644 index 0000000..ba738bc --- /dev/null +++ b/storage/file/file.go @@ -0,0 +1,42 @@ +package file + +import ( + "encoding/json" + "github.com/rancher/dynamiclistener" + "k8s.io/api/core/v1" + "os" +) + +func New(file string) dynamiclistener.TLSStorage { + return &storage{ + file: file, + } +} + +type storage struct { + file string +} + +func (s *storage) Get() (*v1.Secret, error) { + f, err := os.Open(s.file) + if os.IsNotExist(err) { + return nil, nil + } else if err != nil { + return nil, err + } + defer f.Close() + + secret := v1.Secret{} + return &secret, json.NewDecoder(f).Decode(&secret) +} + +func (s *storage) Update(secret *v1.Secret) error { + f, err := os.Create(s.file) + if err != nil { + return err + } + defer f.Close() + + return json.NewEncoder(f).Encode(secret) +} + diff --git a/storage/kubernetes/controller.go b/storage/kubernetes/controller.go new file mode 100644 index 0000000..d5198a2 --- /dev/null +++ b/storage/kubernetes/controller.go @@ -0,0 +1,109 @@ +package kubernetes + +import ( + "context" + "sync" + "time" + + "github.com/rancher/dynamiclistener" + "github.com/rancher/wrangler-api/pkg/generated/controllers/core" + v1controller "github.com/rancher/wrangler-api/pkg/generated/controllers/core/v1" + "github.com/rancher/wrangler/pkg/start" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/equality" +) + +type CoreGetter func() *core.Factory + +func New(ctx context.Context, core CoreGetter, namespace, name string, backing dynamiclistener.TLSStorage) dynamiclistener.TLSStorage { + storage := &storage{ + name: name, + namespace: namespace, + storage: backing, + ctx: ctx, + } + + // lazy init + go func() { + for { + core := core() + if core != nil { + storage.init(core.Core().V1().Secret()) + start.All(ctx, 5, core) + return + } + + select { + case <-ctx.Done(): + return + case <-time.After(time.Second): + } + } + }() + + return storage +} + +type storage struct { + sync.Mutex + + namespace, name string + storage dynamiclistener.TLSStorage + secrets v1controller.SecretClient + ctx context.Context +} + +func (s *storage) init(secrets v1controller.SecretController) { + s.Lock() + defer s.Unlock() + + secrets.OnChange(s.ctx, "tls-storage", func(key string, secret *v1.Secret) (*v1.Secret, error) { + if secret == nil { + return nil, nil + } + if secret.Namespace == s.namespace && secret.Name == s.name { + if err := s.Update(secret); err != nil { + return nil, err + } + } + + return secret, nil + }) + s.secrets = secrets +} + +func (s *storage) Get() (*v1.Secret, error) { + s.Lock() + defer s.Unlock() + + return s.storage.Get() +} + +func (s *storage) Update(secret *v1.Secret) (err error) { + s.Lock() + defer s.Unlock() + + if s.secrets != nil { + if secret.UID == "" { + secret.Name = s.name + secret.Namespace = s.namespace + secret, err = s.secrets.Create(secret) + if err != nil { + return err + } + } else { + existingSecret, err := s.storage.Get() + if err != nil { + return err + } + if !equality.Semantic.DeepEqual(secret.Data, existingSecret.Data) { + secret, err = s.secrets.Update(secret) + if err != nil { + return err + } + } + } + } + + return s.storage.Update(secret) +} diff --git a/storage/memory/memory.go b/storage/memory/memory.go new file mode 100644 index 0000000..079f180 --- /dev/null +++ b/storage/memory/memory.go @@ -0,0 +1,42 @@ +package memory + +import ( + "github.com/rancher/dynamiclistener" + v1 "k8s.io/api/core/v1" +) + +func New() dynamiclistener.TLSStorage { + return &memory{} +} + +func NewBacked(storage dynamiclistener.TLSStorage) dynamiclistener.TLSStorage { + return &memory{storage: storage} +} + +type memory struct { + storage dynamiclistener.TLSStorage + secret *v1.Secret +} + +func (m *memory) Get() (*v1.Secret, error) { + if m.secret == nil && m.storage != nil { + secret, err := m.storage.Get() + if err != nil { + return nil, err + } + m.secret = secret + } + + return m.secret, nil +} + +func (m *memory) Update(secret *v1.Secret) error { + if m.storage != nil { + if err := m.storage.Update(secret); err != nil { + return err + } + } + + m.secret = secret + return nil +} diff --git a/tcp.go b/tcp.go new file mode 100644 index 0000000..3f41588 --- /dev/null +++ b/tcp.go @@ -0,0 +1,38 @@ +package dynamiclistener + +import ( + "fmt" + "net" + "reflect" + "time" +) + +func NewTCPListener(ip string, port int) (net.Listener, error) { + l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", ip, port)) + if err != nil { + return nil, err + } + + tcpListener, ok := l.(*net.TCPListener) + if !ok { + return nil, fmt.Errorf("wrong listener type: %v", reflect.TypeOf(tcpListener)) + } + + return tcpKeepAliveListener{ + TCPListener: tcpListener, + }, nil +} + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} diff --git a/types.go b/types.go deleted file mode 100644 index 7fabf1c..0000000 --- a/types.go +++ /dev/null @@ -1,63 +0,0 @@ -package dynamiclistener - -import ( - "net/http" -) - -type ListenerConfigStorage interface { - Set(*ListenerStatus) (*ListenerStatus, error) - Get() (*ListenerStatus, error) -} - -type ServerInterface interface { - Update(status *ListenerStatus) error - CACert() (string, error) -} - -type UserConfig struct { - // Required fields - - Handler http.Handler - HTTPPort int - HTTPSPort int - CertPath string - - // Optional fields - - KnownIPs []string - Domains []string - Mode string - NoCACerts bool - CACerts string - CAKey string - Cert string - Key string - BindAddress string -} - -type ListenerStatus struct { - Revision string `json:"revision,omitempty"` - CACert string `json:"caCert,omitempty"` - CAKey string `json:"caKey,omitempty"` - GeneratedCerts map[string]string `json:"generatedCerts" norman:"nocreate,noupdate"` - KnownIPs map[string]bool `json:"knownIps" norman:"nocreate,noupdate"` -} - -func (l *ListenerStatus) DeepCopyInto(t *ListenerStatus) { - t.Revision = l.Revision - t.CACert = l.CACert - t.CAKey = l.CAKey - t.GeneratedCerts = copyMap(t.GeneratedCerts) - t.KnownIPs = map[string]bool{} - for k, v := range l.KnownIPs { - t.KnownIPs[k] = v - } -} - -func copyMap(m map[string]string) map[string]string { - ret := map[string]string{} - for k, v := range m { - ret[k] = v - } - return ret -}