forked from github/dynamiclistener
Refactor to not include a server by default
This commit is contained in:
parent
8a2488bc86
commit
af04867843
80
factory/ca.go
Normal file
80
factory/ca.go
Normal file
@ -0,0 +1,80 @@
|
||||
package factory
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
func GenCA() (*x509.Certificate, *ecdsa.PrivateKey, error) {
|
||||
caKey, err := NewPrivateKey()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
caCert, err := NewSelfSignedCACert(caKey, "dynamiclistener-ca", "dynamiclistener-org")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return caCert, caKey, nil
|
||||
}
|
||||
|
||||
func LoadOrGenCA() (*x509.Certificate, *ecdsa.PrivateKey, error) {
|
||||
cert, key, err := loadCA()
|
||||
if err == nil {
|
||||
return cert, key, nil
|
||||
}
|
||||
|
||||
cert, key, err = GenCA()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
certBytes, keyBytes, err := Marshal(cert, key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll("./certs", 0700); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile("./certs/ca.pem", certBytes, 0600); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile("./certs/ca.key", keyBytes, 0600); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return cert, key, nil
|
||||
}
|
||||
|
||||
func loadCA() (*x509.Certificate, *ecdsa.PrivateKey, error) {
|
||||
return LoadCerts("./certs/ca.pem", "./certs/ca.key")
|
||||
}
|
||||
|
||||
func LoadCerts(certFile, keyFile string) (*x509.Certificate, *ecdsa.PrivateKey, error) {
|
||||
caPem, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
caKey, err := ioutil.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
key, err := ParseECPrivateKeyPEM(caKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cert, err := ParseCertPEM(caPem)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return cert, key, nil
|
||||
}
|
105
factory/cert_utils.go
Normal file
105
factory/cert_utils.go
Normal file
@ -0,0 +1,105 @@
|
||||
package factory
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
ECPrivateKeyBlockType = "EC PRIVATE KEY"
|
||||
CertificateBlockType = "CERTIFICATE"
|
||||
)
|
||||
|
||||
func NewSelfSignedCACert(key crypto.Signer, cn string, org ...string) (*x509.Certificate, error) {
|
||||
now := time.Now()
|
||||
tmpl := x509.Certificate{
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
NotAfter: now.Add(time.Hour * 24 * 365 * 10).UTC(),
|
||||
NotBefore: now.UTC(),
|
||||
SerialNumber: new(big.Int).SetInt64(0),
|
||||
Subject: pkix.Name{
|
||||
CommonName: cn,
|
||||
Organization: org,
|
||||
},
|
||||
}
|
||||
|
||||
certDERBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, key.Public(), key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return x509.ParseCertificate(certDERBytes)
|
||||
}
|
||||
|
||||
func NewSignedCert(signer crypto.Signer, caCert *x509.Certificate, caKey crypto.Signer, cn string, orgs []string,
|
||||
domains []string, ips []net.IP) (*x509.Certificate, error) {
|
||||
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).SetInt64(math.MaxInt64))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parent := x509.Certificate{
|
||||
DNSNames: domains,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
IPAddresses: ips,
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
NotAfter: time.Now().Add(time.Hour * 24 * 365).UTC(),
|
||||
NotBefore: caCert.NotBefore,
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: cn,
|
||||
Organization: orgs,
|
||||
},
|
||||
}
|
||||
|
||||
cert, err := x509.CreateCertificate(rand.Reader, &parent, caCert, signer.Public(), caKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return x509.ParseCertificate(cert)
|
||||
}
|
||||
|
||||
func ParseECPrivateKeyPEM(keyData []byte) (*ecdsa.PrivateKey, error) {
|
||||
var privateKeyPemBlock *pem.Block
|
||||
for {
|
||||
privateKeyPemBlock, keyData = pem.Decode(keyData)
|
||||
if privateKeyPemBlock == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if privateKeyPemBlock.Type == ECPrivateKeyBlockType {
|
||||
return x509.ParseECPrivateKey(privateKeyPemBlock.Bytes)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("pem does not include a valid EC private key")
|
||||
}
|
||||
|
||||
func ParseCertPEM(pemCerts []byte) (*x509.Certificate, error) {
|
||||
var pemBlock *pem.Block
|
||||
for {
|
||||
pemBlock, pemCerts = pem.Decode(pemCerts)
|
||||
if pemBlock == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if pemBlock.Type == CertificateBlockType {
|
||||
return x509.ParseCertificate(pemBlock.Bytes)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("pem does not include a valid x509 cert")
|
||||
}
|
164
factory/gen.go
Normal file
164
factory/gen.go
Normal file
@ -0,0 +1,164 @@
|
||||
package factory
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
v1 "k8s.io/api/core/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
cnPrefix = "listener.cattle.io/cn-"
|
||||
static = "listener.cattle.io/static"
|
||||
hashKey = "listener.cattle.io/hash"
|
||||
)
|
||||
|
||||
type TLS struct {
|
||||
CACert *x509.Certificate
|
||||
CAKey crypto.Signer
|
||||
CN string
|
||||
Organization []string
|
||||
}
|
||||
|
||||
func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string, err error) {
|
||||
var (
|
||||
cns []string
|
||||
digest = sha256.New()
|
||||
)
|
||||
for k, v := range secret.Annotations {
|
||||
if strings.HasPrefix(k, cnPrefix) {
|
||||
cns = append(cns, v)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(cns)
|
||||
|
||||
for _, v := range cns {
|
||||
digest.Write([]byte(v))
|
||||
ip := net.ParseIP(v)
|
||||
if ip == nil {
|
||||
domains = append(domains, v)
|
||||
} else {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
hash = hex.EncodeToString(digest.Sum(nil))
|
||||
return
|
||||
}
|
||||
|
||||
func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) {
|
||||
var (
|
||||
err error
|
||||
)
|
||||
|
||||
if !NeedsUpdate(secret, cn...) {
|
||||
return secret, false, nil
|
||||
}
|
||||
|
||||
secret = populateCN(secret, cn...)
|
||||
|
||||
privateKey, err := getPrivateKey(secret)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
domains, ips, hash, err := collectCNs(secret)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
newCert, err := t.newCert(domains, ips, privateKey)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
certBytes, keyBytes, err := Marshal(newCert, privateKey)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if secret.Data == nil {
|
||||
secret.Data = map[string][]byte{}
|
||||
}
|
||||
secret.Data[v1.TLSCertKey] = certBytes
|
||||
secret.Data[v1.TLSPrivateKeyKey] = keyBytes
|
||||
secret.Annotations[hashKey] = hash
|
||||
|
||||
return secret, true, nil
|
||||
}
|
||||
|
||||
func (t *TLS) newCert(domains []string, ips []net.IP, privateKey *ecdsa.PrivateKey) (*x509.Certificate, error) {
|
||||
return NewSignedCert(privateKey, t.CACert, t.CAKey, t.CN, t.Organization, domains, ips)
|
||||
}
|
||||
|
||||
func populateCN(secret *v1.Secret, cn ...string) *v1.Secret {
|
||||
secret = secret.DeepCopy()
|
||||
if secret.Annotations == nil {
|
||||
secret.Annotations = map[string]string{}
|
||||
}
|
||||
for _, cn := range cn {
|
||||
secret.Annotations[cnPrefix+cn] = cn
|
||||
}
|
||||
return secret
|
||||
}
|
||||
|
||||
func NeedsUpdate(secret *v1.Secret, cn ...string) bool {
|
||||
if secret.Annotations[static] == "true" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, cn := range cn {
|
||||
if secret.Annotations[cnPrefix+cn] == "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getPrivateKey(secret *v1.Secret) (*ecdsa.PrivateKey, error) {
|
||||
keyBytes := secret.Data[v1.TLSPrivateKeyKey]
|
||||
if len(keyBytes) == 0 {
|
||||
return NewPrivateKey()
|
||||
}
|
||||
|
||||
privateKey, err := ParseECPrivateKeyPEM(keyBytes)
|
||||
if err == nil {
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
return NewPrivateKey()
|
||||
}
|
||||
|
||||
func Marshal(x509Cert *x509.Certificate, privateKey *ecdsa.PrivateKey) ([]byte, []byte, error) {
|
||||
certBlock := pem.Block{
|
||||
Type: CertificateBlockType,
|
||||
Bytes: x509Cert.Raw,
|
||||
}
|
||||
|
||||
keyBytes, err := x509.MarshalECPrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keyBlock := pem.Block{
|
||||
Type: ECPrivateKeyBlockType,
|
||||
Bytes: keyBytes,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(&certBlock), pem.EncodeToMemory(&keyBlock), nil
|
||||
}
|
||||
|
||||
func NewPrivateKey() (*ecdsa.PrivateKey, error) {
|
||||
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
}
|
13
go.mod
13
go.mod
@ -3,16 +3,7 @@ module github.com/rancher/dynamiclistener
|
||||
go 1.12
|
||||
|
||||
require (
|
||||
github.com/hashicorp/golang-lru v0.5.1
|
||||
github.com/kisielk/gotool v1.0.0 // indirect
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
||||
github.com/rancher/wrangler-api v0.2.0
|
||||
github.com/sirupsen/logrus v1.4.1
|
||||
github.com/stretchr/testify v1.3.0 // indirect
|
||||
github.com/stripe/safesql v0.2.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284
|
||||
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c // indirect
|
||||
golang.org/x/sys v0.0.0-20190509141414-a5b02f93d862 // indirect
|
||||
golang.org/x/text v0.3.2 // indirect
|
||||
mvdan.cc/interfacer v0.0.0-20180901003855-c20040233aed // indirect
|
||||
mvdan.cc/lint v0.0.0-20170908181259-adc824a0674b // indirect
|
||||
k8s.io/api v0.0.0-20190409021203-6e4e0e4f393b
|
||||
)
|
||||
|
198
listener.go
Normal file
198
listener.go
Normal file
@ -0,0 +1,198 @@
|
||||
package dynamiclistener
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rancher/dynamiclistener/factory"
|
||||
"github.com/sirupsen/logrus"
|
||||
v1 "k8s.io/api/core/v1"
|
||||
)
|
||||
|
||||
type TLSStorage interface {
|
||||
Get() (*v1.Secret, error)
|
||||
Update(secret *v1.Secret) error
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
CN string
|
||||
Organization []string
|
||||
TLSConfig tls.Config
|
||||
SANs []string
|
||||
}
|
||||
|
||||
func NewListener(l net.Listener, storage TLSStorage, caCert *x509.Certificate, caKey crypto.Signer, config Config) (net.Listener, http.Handler, error) {
|
||||
if config.CN == "" {
|
||||
config.CN = "dynamic"
|
||||
}
|
||||
if len(config.Organization) == 0 {
|
||||
config.Organization = []string{"dynamic"}
|
||||
}
|
||||
|
||||
dynamicListener := &listener{
|
||||
factory: &factory.TLS{
|
||||
CACert: caCert,
|
||||
CAKey: caKey,
|
||||
CN: config.CN,
|
||||
Organization: config.Organization,
|
||||
},
|
||||
Listener: l,
|
||||
storage: &nonNil{storage: storage},
|
||||
sans: config.SANs,
|
||||
tlsConfig: config.TLSConfig,
|
||||
}
|
||||
dynamicListener.tlsConfig.GetCertificate = dynamicListener.getCertificate
|
||||
|
||||
return tls.NewListener(dynamicListener, &dynamicListener.tlsConfig), dynamicListener.cacheHandler(), nil
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
sync.RWMutex
|
||||
net.Listener
|
||||
|
||||
factory *factory.TLS
|
||||
storage TLSStorage
|
||||
version string
|
||||
tlsConfig tls.Config
|
||||
cert *tls.Certificate
|
||||
sans []string
|
||||
}
|
||||
|
||||
func (l *listener) Accept() (net.Conn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return conn, err
|
||||
}
|
||||
|
||||
addr := conn.RemoteAddr()
|
||||
if addr == nil {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(addr.String())
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to parse network %s: %v", addr.Network(), err)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
return conn, l.updateCert(host)
|
||||
}
|
||||
|
||||
func (l *listener) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if hello.ServerName != "" {
|
||||
if err := l.updateCert(hello.ServerName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return l.loadCert()
|
||||
}
|
||||
|
||||
func (l *listener) updateCert(cn string) error {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
|
||||
secret, err := l.storage.Get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !factory.NeedsUpdate(secret, append(l.sans, cn)...) {
|
||||
return nil
|
||||
}
|
||||
|
||||
l.RUnlock()
|
||||
l.Lock()
|
||||
defer l.RLock()
|
||||
defer l.Unlock()
|
||||
|
||||
secret, updated, err := l.factory.AddCN(secret, append(l.sans, cn)...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := l.storage.Update(secret); err != nil {
|
||||
return err
|
||||
}
|
||||
// clear version to force cert reload
|
||||
l.version = ""
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *listener) loadCert() (*tls.Certificate, error) {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
|
||||
secret, err := l.storage.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if l.cert != nil && l.version == secret.ResourceVersion {
|
||||
return l.cert, nil
|
||||
}
|
||||
|
||||
defer l.RLock()
|
||||
l.RUnlock()
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
secret, err = l.storage.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if l.cert != nil && l.version == secret.ResourceVersion {
|
||||
return l.cert, nil
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.cert = &cert
|
||||
return l.cert, nil
|
||||
}
|
||||
|
||||
func (l *listener) cacheHandler() http.Handler {
|
||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
||||
h, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
h = req.Host
|
||||
}
|
||||
|
||||
ip := net.ParseIP(h)
|
||||
if len(ip) > 0 {
|
||||
l.updateCert(h)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type nonNil struct {
|
||||
sync.Mutex
|
||||
storage TLSStorage
|
||||
}
|
||||
|
||||
func (n *nonNil) Get() (*v1.Secret, error) {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
s, err := n.storage.Get()
|
||||
if err != nil || s == nil {
|
||||
return &v1.Secret{}, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (n *nonNil) Update(secret *v1.Secret) error {
|
||||
n.Lock()
|
||||
defer n.Unlock()
|
||||
|
||||
return n.storage.Update(secret)
|
||||
}
|
50
read.go
50
read.go
@ -1,50 +0,0 @@
|
||||
package dynamiclistener
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func ReadTLSConfig(userConfig *UserConfig) error {
|
||||
var err error
|
||||
|
||||
path := userConfig.CertPath
|
||||
|
||||
userConfig.CACerts, err = readPEM(filepath.Join(path, "cacerts.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userConfig.Key, err = readPEM(filepath.Join(path, "key.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userConfig.Cert, err = readPEM(filepath.Join(path, "cert.pem"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
valid := false
|
||||
if userConfig.Key != "" && userConfig.Cert != "" {
|
||||
valid = true
|
||||
} else if userConfig.Key == "" && userConfig.Cert == "" {
|
||||
valid = true
|
||||
}
|
||||
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid SSL configuration found, please set cert/key, cert/key/cacerts, cacerts only, or none")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readPEM(path string) (string, error) {
|
||||
content, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
}
|
46
redirect.go
Normal file
46
redirect.go
Normal file
@ -0,0 +1,46 @@
|
||||
package dynamiclistener
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Approach taken from letsencrypt, except manglePort is specific to us
|
||||
func HTTPRedirect(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(rw http.ResponseWriter, r *http.Request) {
|
||||
fmt.Println("!!!!!", r.URL.String(), r.Header)
|
||||
if r.Header.Get("x-Forwarded-Proto") == "https" ||
|
||||
r.Header.Get("x-Forwarded-Proto") == "wss" ||
|
||||
strings.HasPrefix(r.URL.Path, "/ping") ||
|
||||
strings.HasPrefix(r.URL.Path, "/health") {
|
||||
next.ServeHTTP(rw, r)
|
||||
return
|
||||
}
|
||||
if r.Method != "GET" && r.Method != "HEAD" {
|
||||
http.Error(rw, "Use HTTPS", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
target := "https://" + manglePort(r.Host) + r.URL.RequestURI()
|
||||
http.Redirect(rw, r, target, http.StatusFound)
|
||||
})
|
||||
}
|
||||
|
||||
func manglePort(hostport string) string {
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
return hostport
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return hostport
|
||||
}
|
||||
|
||||
portInt = ((portInt / 1000) * 1000) + 443
|
||||
|
||||
return net.JoinHostPort(host, strconv.Itoa(portInt))
|
||||
}
|
556
server.go
556
server.go
@ -1,556 +0,0 @@
|
||||
package dynamiclistener
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
cert "github.com/rancher/dynamiclistener/cert"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type server struct {
|
||||
sync.Mutex
|
||||
|
||||
userConfig UserConfig
|
||||
listenConfigStorage ListenerConfigStorage
|
||||
tlsCert *tls.Certificate
|
||||
|
||||
ips map[string]bool
|
||||
domains map[string]bool
|
||||
cn string
|
||||
|
||||
listeners []net.Listener
|
||||
servers []*http.Server
|
||||
|
||||
activeCA *x509.Certificate
|
||||
activeCAKey crypto.Signer
|
||||
}
|
||||
|
||||
func NewServer(listenConfigStorage ListenerConfigStorage, config UserConfig) (ServerInterface, error) {
|
||||
s := &server{
|
||||
userConfig: config,
|
||||
listenConfigStorage: listenConfigStorage,
|
||||
cn: "cattle",
|
||||
}
|
||||
|
||||
s.ips = map[string]bool{}
|
||||
s.domains = map[string]bool{}
|
||||
|
||||
if err := s.userConfigure(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lc, err := listenConfigStorage.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, s.Update(lc)
|
||||
}
|
||||
|
||||
func (s *server) CACert() (string, error) {
|
||||
if s.userConfig.NoCACerts {
|
||||
return "", nil
|
||||
}
|
||||
if s.userConfig.CACerts != "" {
|
||||
return s.userConfig.CACerts, nil
|
||||
}
|
||||
return "", fmt.Errorf("ca cert not found")
|
||||
}
|
||||
|
||||
func marshalPrivateKey(privateKey crypto.Signer) (string, []byte, error) {
|
||||
var (
|
||||
keyType string
|
||||
bytes []byte
|
||||
err error
|
||||
)
|
||||
if key, ok := privateKey.(*ecdsa.PrivateKey); ok {
|
||||
keyType = cert.ECPrivateKeyBlockType
|
||||
bytes, err = x509.MarshalECPrivateKey(key)
|
||||
} else if key, ok := privateKey.(*rsa.PrivateKey); ok {
|
||||
keyType = cert.RSAPrivateKeyBlockType
|
||||
bytes = x509.MarshalPKCS1PrivateKey(key)
|
||||
} else {
|
||||
keyType = cert.PrivateKeyBlockType
|
||||
bytes, err = x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
}
|
||||
if err != nil {
|
||||
logrus.Errorf("Unable to marshal private key: %v", err)
|
||||
}
|
||||
return keyType, bytes, err
|
||||
}
|
||||
|
||||
func newPrivateKey() (crypto.Signer, error) {
|
||||
caKeyBytes, err := cert.MakeEllipticPrivateKeyPEM()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
caKeyIFace, err := cert.ParsePrivateKeyPEM(caKeyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return caKeyIFace.(crypto.Signer), nil
|
||||
}
|
||||
|
||||
func (s *server) save() (_err error) {
|
||||
defer func() {
|
||||
if _err != nil {
|
||||
logrus.Errorf("Saving cert error: %s", _err)
|
||||
}
|
||||
}()
|
||||
|
||||
certStr, err := certToString(s.tlsCert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg, err := s.listenConfigStorage.Get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.GeneratedCerts = map[string]string{s.cn: certStr}
|
||||
|
||||
_, err = s.listenConfigStorage.Set(cfg)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *server) userConfigure() error {
|
||||
if s.userConfig.HTTPSPort == 0 {
|
||||
s.userConfig.HTTPSPort = 8443
|
||||
}
|
||||
|
||||
for _, d := range s.userConfig.Domains {
|
||||
s.domains[d] = true
|
||||
}
|
||||
|
||||
for _, ip := range s.userConfig.KnownIPs {
|
||||
if netIP := net.ParseIP(ip); netIP != nil {
|
||||
s.ips[ip] = true
|
||||
}
|
||||
}
|
||||
|
||||
if bindAddress := net.ParseIP(s.userConfig.BindAddress); bindAddress != nil {
|
||||
s.ips[s.userConfig.BindAddress] = true
|
||||
}
|
||||
|
||||
if s.activeCA == nil && s.activeCAKey == nil {
|
||||
if s.userConfig.CACerts != "" && s.userConfig.CAKey != "" {
|
||||
ca, err := cert.ParseCertsPEM([]byte(s.userConfig.CACerts))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key, err := cert.ParsePrivateKeyPEM([]byte(s.userConfig.CAKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.activeCA = ca[0]
|
||||
s.activeCAKey = key.(crypto.Signer)
|
||||
} else {
|
||||
ca, key, err := genCA()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.activeCA = ca
|
||||
s.activeCAKey = key
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func genCA() (*x509.Certificate, crypto.Signer, error) {
|
||||
caKey, err := newPrivateKey()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
caCert, err := cert.NewSelfSignedCACert(cert.Config{
|
||||
CommonName: "k3s-ca",
|
||||
Organization: []string{"k3s-org"},
|
||||
}, caKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return caCert, caKey, nil
|
||||
}
|
||||
|
||||
func (s *server) Update(status *ListenerStatus) (_err error) {
|
||||
s.Lock()
|
||||
defer func() {
|
||||
s.Unlock()
|
||||
if _err != nil {
|
||||
logrus.Errorf("Update cert error: %s", _err)
|
||||
}
|
||||
if s.tlsCert == nil {
|
||||
s.getCertificate(&tls.ClientHelloInfo{ServerName: "localhost"})
|
||||
}
|
||||
}()
|
||||
|
||||
certString := status.GeneratedCerts[s.cn]
|
||||
tlsCert, err := stringToCert(certString)
|
||||
if err != nil {
|
||||
logrus.Errorf("Update cert unable to convert string to cert: %s", err)
|
||||
s.tlsCert = nil
|
||||
}
|
||||
if tlsCert != nil {
|
||||
s.tlsCert = tlsCert
|
||||
for i, certBytes := range tlsCert.Certificate {
|
||||
parsedCert, err := x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
logrus.Errorf("Update cert %d parse error: %s", i, err)
|
||||
s.tlsCert = nil
|
||||
break
|
||||
}
|
||||
isExpired := cert.IsCertExpired(parsedCert)
|
||||
if isExpired {
|
||||
logrus.Infof("certificate is about to expire")
|
||||
s.tlsCert = nil
|
||||
break
|
||||
}
|
||||
ips := map[string]bool{}
|
||||
for _, ip := range parsedCert.IPAddresses {
|
||||
ips[ip.String()] = true
|
||||
}
|
||||
|
||||
domains := map[string]bool{}
|
||||
for _, domain := range parsedCert.DNSNames {
|
||||
domains[domain] = true
|
||||
}
|
||||
|
||||
if !(reflect.DeepEqual(ips, s.ips) && reflect.DeepEqual(domains, s.domains)) {
|
||||
subset := true
|
||||
for ip := range s.ips {
|
||||
if !ips[ip] {
|
||||
subset = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if subset {
|
||||
for domain := range s.domains {
|
||||
if !domains[domain] {
|
||||
subset = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !subset {
|
||||
s.tlsCert = nil
|
||||
}
|
||||
for ip := range ips {
|
||||
s.ips[ip] = true
|
||||
}
|
||||
for domain := range domains {
|
||||
s.domains[domain] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.reload()
|
||||
}
|
||||
|
||||
func (s *server) shutdown() error {
|
||||
for _, listener := range s.listeners {
|
||||
if err := listener.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.listeners = nil
|
||||
|
||||
for _, server := range s.servers {
|
||||
go server.Shutdown(context.Background())
|
||||
}
|
||||
s.servers = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *server) reload() error {
|
||||
if len(s.listeners) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.shutdown(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.serveHTTPS(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *server) getCertificate(hello *tls.ClientHelloInfo) (_servingCert *tls.Certificate, _err error) {
|
||||
s.Lock()
|
||||
changed := false
|
||||
|
||||
defer func() {
|
||||
defer s.Unlock()
|
||||
|
||||
if _err != nil {
|
||||
logrus.Errorf("Get certificate error: %s", _err)
|
||||
return
|
||||
}
|
||||
|
||||
if changed {
|
||||
s.save()
|
||||
}
|
||||
}()
|
||||
|
||||
if hello.ServerName != "" && !s.domains[hello.ServerName] {
|
||||
s.tlsCert = nil
|
||||
s.domains[hello.ServerName] = true
|
||||
}
|
||||
|
||||
if s.tlsCert != nil {
|
||||
return s.tlsCert, nil
|
||||
}
|
||||
|
||||
ips := []net.IP{}
|
||||
for ipStr := range s.ips {
|
||||
if ip := net.ParseIP(ipStr); ip != nil {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
dnsNames := []string{}
|
||||
for domain := range s.domains {
|
||||
dnsNames = append(dnsNames, domain)
|
||||
}
|
||||
|
||||
cfg := cert.Config{
|
||||
CommonName: s.cn,
|
||||
Organization: s.activeCA.Subject.Organization,
|
||||
Usages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
AltNames: cert.AltNames{
|
||||
DNSNames: dnsNames,
|
||||
IPs: ips,
|
||||
},
|
||||
}
|
||||
|
||||
key, err := newPrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cert, err := cert.NewSignedCert(cfg, key, s.activeCA, s.activeCAKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsCert := &tls.Certificate{
|
||||
Certificate: [][]byte{
|
||||
cert.Raw,
|
||||
},
|
||||
PrivateKey: key,
|
||||
}
|
||||
|
||||
changed = true
|
||||
s.tlsCert = tlsCert
|
||||
return tlsCert, nil
|
||||
}
|
||||
|
||||
func (s *server) cacheHandler(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
||||
h, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
h = req.Host
|
||||
}
|
||||
|
||||
s.Lock()
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
if !s.ips[h] {
|
||||
s.ips[h] = true
|
||||
s.tlsCert = nil
|
||||
}
|
||||
} else {
|
||||
if !s.domains[h] {
|
||||
s.domains[h] = true
|
||||
s.tlsCert = nil
|
||||
}
|
||||
}
|
||||
s.Unlock()
|
||||
|
||||
handler.ServeHTTP(resp, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *server) serveHTTPS() error {
|
||||
conf := &tls.Config{
|
||||
ClientAuth: tls.RequestClientCert,
|
||||
GetCertificate: s.getCertificate,
|
||||
PreferServerCipherSuites: true,
|
||||
}
|
||||
|
||||
listener, err := s.newListener(s.userConfig.BindAddress, s.userConfig.HTTPSPort, conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger := logrus.StandardLogger()
|
||||
server := &http.Server{
|
||||
Handler: s.cacheHandler(s.Handler()),
|
||||
ErrorLog: log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags),
|
||||
}
|
||||
|
||||
s.servers = append(s.servers, server)
|
||||
s.startServer(listener, server)
|
||||
|
||||
if s.userConfig.HTTPPort > 0 {
|
||||
httpListener, err := s.newListener(s.userConfig.BindAddress, s.userConfig.HTTPPort, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpServer := &http.Server{
|
||||
Handler: s.cacheHandler(httpRedirect(s.Handler())),
|
||||
ErrorLog: log.New(logger.WriterLevel(logrus.DebugLevel), "", log.LstdFlags),
|
||||
}
|
||||
|
||||
s.servers = append(s.servers, httpServer)
|
||||
s.startServer(httpListener, httpServer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Approach taken from letsencrypt, except manglePort is specific to us
|
||||
func httpRedirect(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(rw http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("x-Forwarded-Proto") == "https" ||
|
||||
strings.HasPrefix(r.URL.Path, "/ping") ||
|
||||
strings.HasPrefix(r.URL.Path, "/health") {
|
||||
next.ServeHTTP(rw, r)
|
||||
return
|
||||
}
|
||||
if r.Method != "GET" && r.Method != "HEAD" {
|
||||
http.Error(rw, "Use HTTPS", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
target := "https://" + manglePort(r.Host) + r.URL.RequestURI()
|
||||
http.Redirect(rw, r, target, http.StatusFound)
|
||||
})
|
||||
}
|
||||
|
||||
func manglePort(hostport string) string {
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
return hostport
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return hostport
|
||||
}
|
||||
|
||||
portInt = ((portInt / 1000) * 1000) + 443
|
||||
|
||||
return net.JoinHostPort(host, strconv.Itoa(portInt))
|
||||
}
|
||||
|
||||
func (s *server) startServer(listener net.Listener, server *http.Server) {
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil {
|
||||
logrus.Errorf("server on %v returned err: %v", listener.Addr(), err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *server) Handler() http.Handler {
|
||||
return s.userConfig.Handler
|
||||
}
|
||||
|
||||
func (s *server) newListener(ip string, port int, config *tls.Config) (net.Listener, error) {
|
||||
addr := fmt.Sprintf("%s:%d", ip, port)
|
||||
l, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l = tcpKeepAliveListener{l.(*net.TCPListener)}
|
||||
|
||||
if config != nil {
|
||||
l = tls.NewListener(l, config)
|
||||
}
|
||||
|
||||
s.listeners = append(s.listeners, l)
|
||||
logrus.Info("Listening on ", addr)
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func stringToCert(certString string) (*tls.Certificate, error) {
|
||||
parts := strings.Split(certString, "#")
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("Unable to split cert into two parts")
|
||||
}
|
||||
|
||||
certPart, keyPart := parts[0], parts[1]
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(keyPart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := cert.ParsePrivateKeyPEM(keyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certBytes, err := base64.StdEncoding.DecodeString(certPart)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{certBytes},
|
||||
PrivateKey: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func certToString(cert *tls.Certificate) (string, error) {
|
||||
keyType, keyBytes, err := marshalPrivateKey(cert.PrivateKey.(crypto.Signer))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKeyPemBlock := &pem.Block{
|
||||
Type: keyType,
|
||||
Bytes: keyBytes,
|
||||
}
|
||||
pemBytes := pem.EncodeToMemory(privateKeyPemBlock)
|
||||
|
||||
certString := base64.StdEncoding.EncodeToString(cert.Certificate[0])
|
||||
keyString := base64.StdEncoding.EncodeToString(pemBytes)
|
||||
return certString + "#" + keyString, nil
|
||||
}
|
||||
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tc.SetKeepAlive(true)
|
||||
tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
return tc, nil
|
||||
}
|
80
server/server.go
Normal file
80
server/server.go
Normal file
@ -0,0 +1,80 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/rancher/dynamiclistener"
|
||||
"github.com/rancher/dynamiclistener/factory"
|
||||
"github.com/rancher/dynamiclistener/storage/memory"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func ListenAndServe(ctx context.Context, httpsPort, httpPort int, handler http.Handler) error {
|
||||
var (
|
||||
// https listener will change this if http is enabled
|
||||
targetHandler = handler
|
||||
)
|
||||
if httpsPort > 0 {
|
||||
caCert, caKey, err := factory.LoadOrGenCA()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tlsTCPListener, err := dynamiclistener.NewTCPListener("0.0.0.0", httpsPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dynListener, dynHandler, err := dynamiclistener.NewListener(tlsTCPListener, memory.New(), caCert, caKey, dynamiclistener.Config{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetHandler = wrapHandler(dynHandler, handler)
|
||||
tlsServer := http.Server{
|
||||
Handler: targetHandler,
|
||||
}
|
||||
targetHandler = dynamiclistener.HTTPRedirect(targetHandler)
|
||||
|
||||
go func() {
|
||||
logrus.Infof("Listening on 0.0.0.0:%d", httpsPort)
|
||||
err := tlsServer.Serve(dynListener)
|
||||
if err != http.ErrServerClosed && err != nil {
|
||||
logrus.Fatalf("https server failed: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
tlsServer.Shutdown(context.Background())
|
||||
}()
|
||||
}
|
||||
|
||||
if httpPort > 0 {
|
||||
httpServer := http.Server{
|
||||
Addr: fmt.Sprintf("0.0.0.0:%d", httpPort),
|
||||
Handler: targetHandler,
|
||||
}
|
||||
go func() {
|
||||
logrus.Infof("Listening on 0.0.0.0:%d", httpPort)
|
||||
err := httpServer.ListenAndServe()
|
||||
if err != http.ErrServerClosed && err != nil {
|
||||
logrus.Fatalf("http server failed: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
httpServer.Shutdown(context.Background())
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func wrapHandler(handler http.Handler, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
handler.ServeHTTP(rw, req)
|
||||
next.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
42
storage/file/file.go
Normal file
42
storage/file/file.go
Normal file
@ -0,0 +1,42 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/rancher/dynamiclistener"
|
||||
"k8s.io/api/core/v1"
|
||||
"os"
|
||||
)
|
||||
|
||||
func New(file string) dynamiclistener.TLSStorage {
|
||||
return &storage{
|
||||
file: file,
|
||||
}
|
||||
}
|
||||
|
||||
type storage struct {
|
||||
file string
|
||||
}
|
||||
|
||||
func (s *storage) Get() (*v1.Secret, error) {
|
||||
f, err := os.Open(s.file)
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
secret := v1.Secret{}
|
||||
return &secret, json.NewDecoder(f).Decode(&secret)
|
||||
}
|
||||
|
||||
func (s *storage) Update(secret *v1.Secret) error {
|
||||
f, err := os.Create(s.file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return json.NewEncoder(f).Encode(secret)
|
||||
}
|
||||
|
109
storage/kubernetes/controller.go
Normal file
109
storage/kubernetes/controller.go
Normal file
@ -0,0 +1,109 @@
|
||||
package kubernetes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rancher/dynamiclistener"
|
||||
"github.com/rancher/wrangler-api/pkg/generated/controllers/core"
|
||||
v1controller "github.com/rancher/wrangler-api/pkg/generated/controllers/core/v1"
|
||||
"github.com/rancher/wrangler/pkg/start"
|
||||
v1 "k8s.io/api/core/v1"
|
||||
"k8s.io/apimachinery/pkg/api/equality"
|
||||
)
|
||||
|
||||
type CoreGetter func() *core.Factory
|
||||
|
||||
func New(ctx context.Context, core CoreGetter, namespace, name string, backing dynamiclistener.TLSStorage) dynamiclistener.TLSStorage {
|
||||
storage := &storage{
|
||||
name: name,
|
||||
namespace: namespace,
|
||||
storage: backing,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
// lazy init
|
||||
go func() {
|
||||
for {
|
||||
core := core()
|
||||
if core != nil {
|
||||
storage.init(core.Core().V1().Secret())
|
||||
start.All(ctx, 5, core)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
type storage struct {
|
||||
sync.Mutex
|
||||
|
||||
namespace, name string
|
||||
storage dynamiclistener.TLSStorage
|
||||
secrets v1controller.SecretClient
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s *storage) init(secrets v1controller.SecretController) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
secrets.OnChange(s.ctx, "tls-storage", func(key string, secret *v1.Secret) (*v1.Secret, error) {
|
||||
if secret == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if secret.Namespace == s.namespace && secret.Name == s.name {
|
||||
if err := s.Update(secret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
})
|
||||
s.secrets = secrets
|
||||
}
|
||||
|
||||
func (s *storage) Get() (*v1.Secret, error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
return s.storage.Get()
|
||||
}
|
||||
|
||||
func (s *storage) Update(secret *v1.Secret) (err error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if s.secrets != nil {
|
||||
if secret.UID == "" {
|
||||
secret.Name = s.name
|
||||
secret.Namespace = s.namespace
|
||||
secret, err = s.secrets.Create(secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
existingSecret, err := s.storage.Get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !equality.Semantic.DeepEqual(secret.Data, existingSecret.Data) {
|
||||
secret, err = s.secrets.Update(secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.storage.Update(secret)
|
||||
}
|
42
storage/memory/memory.go
Normal file
42
storage/memory/memory.go
Normal file
@ -0,0 +1,42 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"github.com/rancher/dynamiclistener"
|
||||
v1 "k8s.io/api/core/v1"
|
||||
)
|
||||
|
||||
func New() dynamiclistener.TLSStorage {
|
||||
return &memory{}
|
||||
}
|
||||
|
||||
func NewBacked(storage dynamiclistener.TLSStorage) dynamiclistener.TLSStorage {
|
||||
return &memory{storage: storage}
|
||||
}
|
||||
|
||||
type memory struct {
|
||||
storage dynamiclistener.TLSStorage
|
||||
secret *v1.Secret
|
||||
}
|
||||
|
||||
func (m *memory) Get() (*v1.Secret, error) {
|
||||
if m.secret == nil && m.storage != nil {
|
||||
secret, err := m.storage.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.secret = secret
|
||||
}
|
||||
|
||||
return m.secret, nil
|
||||
}
|
||||
|
||||
func (m *memory) Update(secret *v1.Secret) error {
|
||||
if m.storage != nil {
|
||||
if err := m.storage.Update(secret); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
m.secret = secret
|
||||
return nil
|
||||
}
|
38
tcp.go
Normal file
38
tcp.go
Normal file
@ -0,0 +1,38 @@
|
||||
package dynamiclistener
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewTCPListener(ip string, port int) (net.Listener, error) {
|
||||
l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", ip, port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tcpListener, ok := l.(*net.TCPListener)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("wrong listener type: %v", reflect.TypeOf(tcpListener))
|
||||
}
|
||||
|
||||
return tcpKeepAliveListener{
|
||||
TCPListener: tcpListener,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tc.SetKeepAlive(true)
|
||||
tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
return tc, nil
|
||||
}
|
63
types.go
63
types.go
@ -1,63 +0,0 @@
|
||||
package dynamiclistener
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ListenerConfigStorage interface {
|
||||
Set(*ListenerStatus) (*ListenerStatus, error)
|
||||
Get() (*ListenerStatus, error)
|
||||
}
|
||||
|
||||
type ServerInterface interface {
|
||||
Update(status *ListenerStatus) error
|
||||
CACert() (string, error)
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
// Required fields
|
||||
|
||||
Handler http.Handler
|
||||
HTTPPort int
|
||||
HTTPSPort int
|
||||
CertPath string
|
||||
|
||||
// Optional fields
|
||||
|
||||
KnownIPs []string
|
||||
Domains []string
|
||||
Mode string
|
||||
NoCACerts bool
|
||||
CACerts string
|
||||
CAKey string
|
||||
Cert string
|
||||
Key string
|
||||
BindAddress string
|
||||
}
|
||||
|
||||
type ListenerStatus struct {
|
||||
Revision string `json:"revision,omitempty"`
|
||||
CACert string `json:"caCert,omitempty"`
|
||||
CAKey string `json:"caKey,omitempty"`
|
||||
GeneratedCerts map[string]string `json:"generatedCerts" norman:"nocreate,noupdate"`
|
||||
KnownIPs map[string]bool `json:"knownIps" norman:"nocreate,noupdate"`
|
||||
}
|
||||
|
||||
func (l *ListenerStatus) DeepCopyInto(t *ListenerStatus) {
|
||||
t.Revision = l.Revision
|
||||
t.CACert = l.CACert
|
||||
t.CAKey = l.CAKey
|
||||
t.GeneratedCerts = copyMap(t.GeneratedCerts)
|
||||
t.KnownIPs = map[string]bool{}
|
||||
for k, v := range l.KnownIPs {
|
||||
t.KnownIPs[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func copyMap(m map[string]string) map[string]string {
|
||||
ret := map[string]string{}
|
||||
for k, v := range m {
|
||||
ret[k] = v
|
||||
}
|
||||
return ret
|
||||
}
|
Loading…
Reference in New Issue
Block a user