diff --git a/server.go b/server.go index 992d8aa..896acdf 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,8 @@ package dynamiclistener import ( "bytes" "context" + "crypto" + "crypto/ecdsa" "crypto/md5" "crypto/rsa" "crypto/tls" @@ -46,7 +48,7 @@ type server struct { // dynamic config change on refresh activeCert *tls.Certificate activeCA *x509.Certificate - activeCAKey *rsa.PrivateKey + activeCAKey crypto.Signer activeCAKeyString string domains map[string]bool } @@ -91,6 +93,40 @@ func (s *server) CACert() (string, error) { return status.CACert, nil } +func marshalPrivateKey(privateKey crypto.Signer) (string, []byte, error) { + var ( + keyType string + bytes []byte + err error + ) + if key, ok := privateKey.(*ecdsa.PrivateKey); ok { + keyType = cert.ECPrivateKeyBlockType + bytes, err = x509.MarshalECPrivateKey(key) + } else if key, ok := privateKey.(*rsa.PrivateKey); ok { + keyType = cert.RSAPrivateKeyBlockType + bytes = x509.MarshalPKCS1PrivateKey(key) + } else { + keyType = cert.PrivateKeyBlockType + bytes, err = x509.MarshalPKCS8PrivateKey(privateKey) + } + if err != nil { + logrus.Errorf("Unable to marshal private key: %v", err) + } + return keyType, bytes, err +} + +func newPrivateKey() (crypto.Signer, error) { + caKeyBytes, err := cert.MakeEllipticPrivateKeyPEM() + if err != nil { + return nil, err + } + caKeyIFace, err := cert.ParsePrivateKeyPEM(caKeyBytes) + if err != nil { + return nil, err + } + return caKeyIFace.(crypto.Signer), nil +} + func (s *server) save() { if s.activeCert != nil { return @@ -114,7 +150,10 @@ func (s *server) save() { } for key, cert := range s.certs { - certStr := certToString(cert) + certStr, err := certToString(cert) + if err != nil { + continue + } if cfg.GeneratedCerts[key] != certStr { cfg.GeneratedCerts[key] = certStr changed = true @@ -139,9 +178,14 @@ func (s *server) save() { } caKeyBuffer := bytes.Buffer{} + keyType, keyBytes, err := marshalPrivateKey(s.activeCAKey) + if err != nil { + return + } + if err := pem.Encode(&caKeyBuffer, &pem.Block{ - Type: cert.RSAPrivateKeyBlockType, - Bytes: x509.MarshalPKCS1PrivateKey(s.activeCAKey), + Type: keyType, + Bytes: keyBytes, }); err != nil { return } @@ -198,8 +242,8 @@ func (s *server) userConfigure() error { return nil } -func genCA() (*x509.Certificate, *rsa.PrivateKey, error) { - caKey, err := cert.NewPrivateKey() +func genCA() (*x509.Certificate, crypto.Signer, error) { + caKey, err := newPrivateKey() if err != nil { return nil, nil, err } @@ -225,7 +269,7 @@ func (s *server) Update(status *ListenerStatus) error { s.Unlock() return err } - s.activeCAKey = cert.PrivateKey.(*rsa.PrivateKey) + s.activeCAKey = cert.PrivateKey.(crypto.Signer) s.activeCAKeyString = status.CAKey x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) @@ -398,7 +442,7 @@ func (s *server) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, e }, } - key, err := cert.NewPrivateKey() + key, err := newPrivateKey() if err != nil { return nil, err } @@ -603,32 +647,36 @@ func stringToCert(certString string) *tls.Certificate { return nil } - cert, key := parts[0], parts[1] - keyBytes, err := base64.StdEncoding.DecodeString(key) + certPart, keyPart := parts[0], parts[1] + keyBytes, err := base64.StdEncoding.DecodeString(keyPart) if err != nil { return nil } - rsaKey, err := x509.ParsePKCS1PrivateKey(keyBytes) + key, err := cert.ParsePrivateKeyPEM(keyBytes) if err != nil { return nil } - certBytes, err := base64.StdEncoding.DecodeString(cert) + certBytes, err := base64.StdEncoding.DecodeString(certPart) if err != nil { return nil } return &tls.Certificate{ Certificate: [][]byte{certBytes}, - PrivateKey: rsaKey, + PrivateKey: key, } } -func certToString(cert *tls.Certificate) string { +func certToString(cert *tls.Certificate) (string, error) { + _, keyBytes, err := marshalPrivateKey(cert.PrivateKey.(crypto.Signer)) + if err != nil { + return "", err + } certString := base64.StdEncoding.EncodeToString(cert.Certificate[0]) - keyString := base64.StdEncoding.EncodeToString(x509.MarshalPKCS1PrivateKey(cert.PrivateKey.(*rsa.PrivateKey))) - return certString + "#" + keyString + keyString := base64.StdEncoding.EncodeToString(keyBytes) + return certString + "#" + keyString, nil } type tcpKeepAliveListener struct {