mirror of
https://github.com/rancher/dynamiclistener.git
synced 2025-06-03 20:30:13 +00:00
557 lines
12 KiB
Go
557 lines
12 KiB
Go
package dynamiclistener
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
cert "github.com/rancher/dynamiclistener/cert"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
type server struct {
|
|
sync.Mutex
|
|
|
|
userConfig UserConfig
|
|
listenConfigStorage ListenerConfigStorage
|
|
tlsCert *tls.Certificate
|
|
|
|
ips map[string]bool
|
|
domains map[string]bool
|
|
cn string
|
|
|
|
listeners []net.Listener
|
|
servers []*http.Server
|
|
|
|
activeCA *x509.Certificate
|
|
activeCAKey crypto.Signer
|
|
}
|
|
|
|
func NewServer(listenConfigStorage ListenerConfigStorage, config UserConfig) (ServerInterface, error) {
|
|
s := &server{
|
|
userConfig: config,
|
|
listenConfigStorage: listenConfigStorage,
|
|
cn: "cattle",
|
|
}
|
|
|
|
s.ips = map[string]bool{}
|
|
s.domains = map[string]bool{}
|
|
|
|
if err := s.userConfigure(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
lc, err := listenConfigStorage.Get()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s, s.Update(lc)
|
|
}
|
|
|
|
func (s *server) CACert() (string, error) {
|
|
if s.userConfig.NoCACerts {
|
|
return "", nil
|
|
}
|
|
if s.userConfig.CACerts != "" {
|
|
return s.userConfig.CACerts, nil
|
|
}
|
|
return "", fmt.Errorf("ca cert not found")
|
|
}
|
|
|
|
func marshalPrivateKey(privateKey crypto.Signer) (string, []byte, error) {
|
|
var (
|
|
keyType string
|
|
bytes []byte
|
|
err error
|
|
)
|
|
if key, ok := privateKey.(*ecdsa.PrivateKey); ok {
|
|
keyType = cert.ECPrivateKeyBlockType
|
|
bytes, err = x509.MarshalECPrivateKey(key)
|
|
} else if key, ok := privateKey.(*rsa.PrivateKey); ok {
|
|
keyType = cert.RSAPrivateKeyBlockType
|
|
bytes = x509.MarshalPKCS1PrivateKey(key)
|
|
} else {
|
|
keyType = cert.PrivateKeyBlockType
|
|
bytes, err = x509.MarshalPKCS8PrivateKey(privateKey)
|
|
}
|
|
if err != nil {
|
|
logrus.Errorf("Unable to marshal private key: %v", err)
|
|
}
|
|
return keyType, bytes, err
|
|
}
|
|
|
|
func newPrivateKey() (crypto.Signer, error) {
|
|
caKeyBytes, err := cert.MakeEllipticPrivateKeyPEM()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
caKeyIFace, err := cert.ParsePrivateKeyPEM(caKeyBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return caKeyIFace.(crypto.Signer), nil
|
|
}
|
|
|
|
func (s *server) save() (_err error) {
|
|
defer func() {
|
|
if _err != nil {
|
|
logrus.Errorf("Saving cert error: %s", _err)
|
|
}
|
|
}()
|
|
|
|
certStr, err := certToString(s.tlsCert)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cfg, err := s.listenConfigStorage.Get()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cfg.GeneratedCerts = map[string]string{s.cn: certStr}
|
|
|
|
_, err = s.listenConfigStorage.Set(cfg)
|
|
return err
|
|
}
|
|
|
|
func (s *server) userConfigure() error {
|
|
if s.userConfig.HTTPSPort == 0 {
|
|
s.userConfig.HTTPSPort = 8443
|
|
}
|
|
|
|
for _, d := range s.userConfig.Domains {
|
|
s.domains[d] = true
|
|
}
|
|
|
|
for _, ip := range s.userConfig.KnownIPs {
|
|
if netIP := net.ParseIP(ip); netIP != nil {
|
|
s.ips[ip] = true
|
|
}
|
|
}
|
|
|
|
if bindAddress := net.ParseIP(s.userConfig.BindAddress); bindAddress != nil {
|
|
s.ips[s.userConfig.BindAddress] = true
|
|
}
|
|
|
|
if s.activeCA == nil && s.activeCAKey == nil {
|
|
if s.userConfig.CACerts != "" && s.userConfig.CAKey != "" {
|
|
ca, err := cert.ParseCertsPEM([]byte(s.userConfig.CACerts))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
key, err := cert.ParsePrivateKeyPEM([]byte(s.userConfig.CAKey))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.activeCA = ca[0]
|
|
s.activeCAKey = key.(crypto.Signer)
|
|
} else {
|
|
ca, key, err := genCA()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.activeCA = ca
|
|
s.activeCAKey = key
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func genCA() (*x509.Certificate, crypto.Signer, error) {
|
|
caKey, err := newPrivateKey()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
caCert, err := cert.NewSelfSignedCACert(cert.Config{
|
|
CommonName: "k3s-ca",
|
|
Organization: []string{"k3s-org"},
|
|
}, caKey)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return caCert, caKey, nil
|
|
}
|
|
|
|
func (s *server) Update(status *ListenerStatus) (_err error) {
|
|
s.Lock()
|
|
defer func() {
|
|
s.Unlock()
|
|
if _err != nil {
|
|
logrus.Errorf("Update cert error: %s", _err)
|
|
}
|
|
if s.tlsCert == nil {
|
|
s.getCertificate(&tls.ClientHelloInfo{ServerName: "localhost"})
|
|
}
|
|
}()
|
|
|
|
certString := status.GeneratedCerts[s.cn]
|
|
tlsCert, err := stringToCert(certString)
|
|
if err != nil {
|
|
logrus.Errorf("Update cert unable to convert string to cert: %s", err)
|
|
s.tlsCert = nil
|
|
}
|
|
if tlsCert != nil {
|
|
s.tlsCert = tlsCert
|
|
for i, certBytes := range tlsCert.Certificate {
|
|
parsedCert, err := x509.ParseCertificate(certBytes)
|
|
if err != nil {
|
|
logrus.Errorf("Update cert %d parse error: %s", i, err)
|
|
s.tlsCert = nil
|
|
break
|
|
}
|
|
isExpired := cert.IsCertExpired(parsedCert)
|
|
if isExpired {
|
|
logrus.Infof("certificate is about to expire")
|
|
s.tlsCert = nil
|
|
break
|
|
}
|
|
ips := map[string]bool{}
|
|
for _, ip := range parsedCert.IPAddresses {
|
|
ips[ip.String()] = true
|
|
}
|
|
|
|
domains := map[string]bool{}
|
|
for _, domain := range parsedCert.DNSNames {
|
|
domains[domain] = true
|
|
}
|
|
|
|
if !(reflect.DeepEqual(ips, s.ips) && reflect.DeepEqual(domains, s.domains)) {
|
|
subset := true
|
|
for ip := range s.ips {
|
|
if !ips[ip] {
|
|
subset = false
|
|
break
|
|
}
|
|
}
|
|
if subset {
|
|
for domain := range s.domains {
|
|
if !domains[domain] {
|
|
subset = false
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if !subset {
|
|
s.tlsCert = nil
|
|
}
|
|
for ip := range ips {
|
|
s.ips[ip] = true
|
|
}
|
|
for domain := range domains {
|
|
s.domains[domain] = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return s.reload()
|
|
}
|
|
|
|
func (s *server) shutdown() error {
|
|
for _, listener := range s.listeners {
|
|
if err := listener.Close(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
s.listeners = nil
|
|
|
|
for _, server := range s.servers {
|
|
go server.Shutdown(context.Background())
|
|
}
|
|
s.servers = nil
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *server) reload() error {
|
|
if len(s.listeners) > 0 {
|
|
return nil
|
|
}
|
|
|
|
if err := s.shutdown(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.serveHTTPS(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *server) getCertificate(hello *tls.ClientHelloInfo) (_servingCert *tls.Certificate, _err error) {
|
|
s.Lock()
|
|
changed := false
|
|
|
|
defer func() {
|
|
defer s.Unlock()
|
|
|
|
if _err != nil {
|
|
logrus.Errorf("Get certificate error: %s", _err)
|
|
return
|
|
}
|
|
|
|
if changed {
|
|
s.save()
|
|
}
|
|
}()
|
|
|
|
if hello.ServerName != "" && !s.domains[hello.ServerName] {
|
|
s.tlsCert = nil
|
|
s.domains[hello.ServerName] = true
|
|
}
|
|
|
|
if s.tlsCert != nil {
|
|
return s.tlsCert, nil
|
|
}
|
|
|
|
ips := []net.IP{}
|
|
for ipStr := range s.ips {
|
|
if ip := net.ParseIP(ipStr); ip != nil {
|
|
ips = append(ips, ip)
|
|
}
|
|
}
|
|
|
|
dnsNames := []string{}
|
|
for domain := range s.domains {
|
|
dnsNames = append(dnsNames, domain)
|
|
}
|
|
|
|
cfg := cert.Config{
|
|
CommonName: s.cn,
|
|
Organization: s.activeCA.Subject.Organization,
|
|
Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
AltNames: cert.AltNames{
|
|
DNSNames: dnsNames,
|
|
IPs: ips,
|
|
},
|
|
}
|
|
|
|
key, err := newPrivateKey()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cert, err := cert.NewSignedCert(cfg, key, s.activeCA, s.activeCAKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tlsCert := &tls.Certificate{
|
|
Certificate: [][]byte{
|
|
cert.Raw,
|
|
},
|
|
PrivateKey: key,
|
|
}
|
|
|
|
changed = true
|
|
s.tlsCert = tlsCert
|
|
return tlsCert, nil
|
|
}
|
|
|
|
func (s *server) cacheHandler(handler http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
|
h, _, err := net.SplitHostPort(req.Host)
|
|
if err != nil {
|
|
h = req.Host
|
|
}
|
|
|
|
s.Lock()
|
|
if ip := net.ParseIP(h); ip != nil {
|
|
if !s.ips[h] {
|
|
s.ips[h] = true
|
|
s.tlsCert = nil
|
|
}
|
|
} else {
|
|
if !s.domains[h] {
|
|
s.domains[h] = true
|
|
s.tlsCert = nil
|
|
}
|
|
}
|
|
s.Unlock()
|
|
|
|
handler.ServeHTTP(resp, req)
|
|
})
|
|
}
|
|
|
|
func (s *server) serveHTTPS() error {
|
|
conf := &tls.Config{
|
|
ClientAuth: tls.RequestClientCert,
|
|
GetCertificate: s.getCertificate,
|
|
PreferServerCipherSuites: true,
|
|
}
|
|
|
|
listener, err := s.newListener(s.userConfig.BindAddress, s.userConfig.HTTPSPort, conf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
logger := logrus.StandardLogger()
|
|
server := &http.Server{
|
|
Handler: s.cacheHandler(s.Handler()),
|
|
ErrorLog: log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags),
|
|
}
|
|
|
|
s.servers = append(s.servers, server)
|
|
s.startServer(listener, server)
|
|
|
|
if s.userConfig.HTTPPort > 0 {
|
|
httpListener, err := s.newListener(s.userConfig.BindAddress, s.userConfig.HTTPPort, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
httpServer := &http.Server{
|
|
Handler: s.cacheHandler(httpRedirect(s.Handler())),
|
|
ErrorLog: log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags),
|
|
}
|
|
|
|
s.servers = append(s.servers, httpServer)
|
|
s.startServer(httpListener, httpServer)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Approach taken from letsencrypt, except manglePort is specific to us
|
|
func httpRedirect(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(
|
|
func(rw http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("x-Forwarded-Proto") == "https" ||
|
|
strings.HasPrefix(r.URL.Path, "/ping") ||
|
|
strings.HasPrefix(r.URL.Path, "/health") {
|
|
next.ServeHTTP(rw, r)
|
|
return
|
|
}
|
|
if r.Method != "GET" && r.Method != "HEAD" {
|
|
http.Error(rw, "Use HTTPS", http.StatusBadRequest)
|
|
return
|
|
}
|
|
target := "https://" + manglePort(r.Host) + r.URL.RequestURI()
|
|
http.Redirect(rw, r, target, http.StatusFound)
|
|
})
|
|
}
|
|
|
|
// manglePort maps the port of a "host:port" string to the matching HTTPS
// port: the port is rounded down to a multiple of 1000 and 443 is added, so
// 80 becomes 443 and 8080 becomes 8443. Input without a port, or with a
// non-numeric port, is returned unchanged.
func manglePort(hostport string) string {
	host, portText, err := net.SplitHostPort(hostport)
	if err != nil {
		return hostport
	}

	port, err := strconv.Atoi(portText)
	if err != nil {
		return hostport
	}

	httpsPort := port - port%1000 + 443
	return net.JoinHostPort(host, strconv.Itoa(httpsPort))
}
|
|
|
|
func (s *server) startServer(listener net.Listener, server *http.Server) {
|
|
go func() {
|
|
if err := server.Serve(listener); err != nil {
|
|
logrus.Errorf("server on %v returned err: %v", listener.Addr(), err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (s *server) Handler() http.Handler {
|
|
return s.userConfig.Handler
|
|
}
|
|
|
|
func (s *server) newListener(ip string, port int, config *tls.Config) (net.Listener, error) {
|
|
addr := fmt.Sprintf("%s:%d", ip, port)
|
|
l, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
l = tcpKeepAliveListener{l.(*net.TCPListener)}
|
|
|
|
if config != nil {
|
|
l = tls.NewListener(l, config)
|
|
}
|
|
|
|
s.listeners = append(s.listeners, l)
|
|
logrus.Info("Listening on ", addr)
|
|
return l, nil
|
|
}
|
|
|
|
func stringToCert(certString string) (*tls.Certificate, error) {
|
|
parts := strings.Split(certString, "#")
|
|
if len(parts) != 2 {
|
|
return nil, errors.New("Unable to split cert into two parts")
|
|
}
|
|
|
|
certPart, keyPart := parts[0], parts[1]
|
|
keyBytes, err := base64.StdEncoding.DecodeString(keyPart)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
key, err := cert.ParsePrivateKeyPEM(keyBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
certBytes, err := base64.StdEncoding.DecodeString(certPart)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &tls.Certificate{
|
|
Certificate: [][]byte{certBytes},
|
|
PrivateKey: key,
|
|
}, nil
|
|
}
|
|
|
|
func certToString(cert *tls.Certificate) (string, error) {
|
|
keyType, keyBytes, err := marshalPrivateKey(cert.PrivateKey.(crypto.Signer))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
privateKeyPemBlock := &pem.Block{
|
|
Type: keyType,
|
|
Bytes: keyBytes,
|
|
}
|
|
pemBytes := pem.EncodeToMemory(privateKeyPemBlock)
|
|
|
|
certString := base64.StdEncoding.EncodeToString(cert.Certificate[0])
|
|
keyString := base64.StdEncoding.EncodeToString(pemBytes)
|
|
return certString + "#" + keyString, nil
|
|
}
|
|
|
|
// tcpKeepAliveListener wraps a *net.TCPListener so that every accepted
// connection has TCP keep-alive enabled.
type tcpKeepAliveListener struct {
	*net.TCPListener
}

// Accept waits for the next TCP connection, enables keep-alive on it with a
// three-minute period and returns it. Errors from setting the keep-alive
// options are ignored; an accept error is returned with a nil connection.
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
	conn, acceptErr := ln.AcceptTCP()
	if acceptErr != nil {
		return nil, acceptErr
	}

	conn.SetKeepAlive(true)
	conn.SetKeepAlivePeriod(3 * time.Minute)
	return conn, nil
}
|