mirror of
				https://github.com/k3s-io/kubernetes.git
				synced 2025-10-31 05:40:42 +00:00 
			
		
		
		
	Add --tls-sni-cert-key to the apiserver for SNI support
This commit is contained in:
		| @@ -502,9 +502,19 @@ func InitializeTLS(kc *componentconfig.KubeletConfiguration) (*server.TLSOptions | |||||||
| 		kc.TLSCertFile = path.Join(kc.CertDirectory, "kubelet.crt") | 		kc.TLSCertFile = path.Join(kc.CertDirectory, "kubelet.crt") | ||||||
| 		kc.TLSPrivateKeyFile = path.Join(kc.CertDirectory, "kubelet.key") | 		kc.TLSPrivateKeyFile = path.Join(kc.CertDirectory, "kubelet.key") | ||||||
| 		if !certutil.CanReadCertOrKey(kc.TLSCertFile, kc.TLSPrivateKeyFile) { | 		if !certutil.CanReadCertOrKey(kc.TLSCertFile, kc.TLSPrivateKeyFile) { | ||||||
| 			if err := certutil.GenerateSelfSignedCert(nodeutil.GetHostname(kc.HostnameOverride), kc.TLSCertFile, kc.TLSPrivateKeyFile, nil, nil); err != nil { | 			cert, key, err := certutil.GenerateSelfSignedCertKey(nodeutil.GetHostname(kc.HostnameOverride), nil, nil) | ||||||
|  | 			if err != nil { | ||||||
| 				return nil, fmt.Errorf("unable to generate self signed cert: %v", err) | 				return nil, fmt.Errorf("unable to generate self signed cert: %v", err) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			if err := certutil.WriteCert(kc.TLSCertFile, cert); err != nil { | ||||||
|  | 				return nil, err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if err := certutil.WriteKey(kc.TLSPrivateKeyFile, key); err != nil { | ||||||
|  | 				return nil, err | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			glog.V(4).Infof("Using self-signed cert (%s, %s)", kc.TLSCertFile, kc.TLSPrivateKeyFile) | 			glog.V(4).Infof("Using self-signed cert (%s, %s)", kc.TLSCertFile, kc.TLSPrivateKeyFile) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -555,6 +555,7 @@ tls-ca-file | |||||||
| tls-cert-file | tls-cert-file | ||||||
| tls-private-key-file | tls-private-key-file | ||||||
| to-version | to-version | ||||||
|  | tls-sni-cert-key | ||||||
| token-auth-file | token-auth-file | ||||||
| ttl-keys-prefix | ttl-keys-prefix | ||||||
| ttl-secs | ttl-secs | ||||||
|   | |||||||
| @@ -110,7 +110,7 @@ type Config struct { | |||||||
| 	// same value for this field. (Numbers > 1 currently untested.) | 	// same value for this field. (Numbers > 1 currently untested.) | ||||||
| 	MasterCount int | 	MasterCount int | ||||||
|  |  | ||||||
| 	SecureServingInfo   *ServingInfo | 	SecureServingInfo   *SecureServingInfo | ||||||
| 	InsecureServingInfo *ServingInfo | 	InsecureServingInfo *ServingInfo | ||||||
|  |  | ||||||
| 	// The port on PublicAddress where a read-write server will be installed. | 	// The port on PublicAddress where a read-write server will be installed. | ||||||
| @@ -177,17 +177,36 @@ type Config struct { | |||||||
| type ServingInfo struct { | type ServingInfo struct { | ||||||
| 	// BindAddress is the ip:port to serve on | 	// BindAddress is the ip:port to serve on | ||||||
| 	BindAddress string | 	BindAddress string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type SecureServingInfo struct { | ||||||
|  | 	ServingInfo | ||||||
|  |  | ||||||
| 	// ServerCert is the TLS cert info for serving secure traffic | 	// ServerCert is the TLS cert info for serving secure traffic | ||||||
| 	ServerCert CertInfo | 	ServerCert GeneratableKeyCert | ||||||
|  | 	// SNICerts are named CertKeys for serving secure traffic with SNI support. | ||||||
|  | 	SNICerts []NamedCertKey | ||||||
| 	// ClientCA is the certificate bundle for all the signers that you'll recognize for incoming client certificates | 	// ClientCA is the certificate bundle for all the signers that you'll recognize for incoming client certificates | ||||||
| 	ClientCA string | 	ClientCA string | ||||||
| } | } | ||||||
|  |  | ||||||
| type CertInfo struct { | type CertKey struct { | ||||||
| 	// CertFile is a file containing a PEM-encoded certificate | 	// CertFile is a file containing a PEM-encoded certificate | ||||||
| 	CertFile string | 	CertFile string | ||||||
| 	// KeyFile is a file containing a PEM-encoded private key for the certificate specified by CertFile | 	// KeyFile is a file containing a PEM-encoded private key for the certificate specified by CertFile | ||||||
| 	KeyFile string | 	KeyFile string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type NamedCertKey struct { | ||||||
|  | 	CertKey | ||||||
|  |  | ||||||
|  | 	// Names is a list of domain patterns: fully qualified domain names, possibly prefixed with | ||||||
|  | 	// wildcard segments. | ||||||
|  | 	Names []string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GeneratableKeyCert struct { | ||||||
|  | 	CertKey | ||||||
| 	// Generate indicates that the cert/key pair should be generated if its not present. | 	// Generate indicates that the cert/key pair should be generated if its not present. | ||||||
| 	Generate bool | 	Generate bool | ||||||
| } | } | ||||||
| @@ -248,12 +267,17 @@ func (c *Config) ApplyOptions(options *options.ServerRunOptions) *Config { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if options.SecurePort > 0 { | 	if options.SecurePort > 0 { | ||||||
| 		secureServingInfo := &ServingInfo{ | 		secureServingInfo := &SecureServingInfo{ | ||||||
|  | 			ServingInfo: ServingInfo{ | ||||||
| 				BindAddress: net.JoinHostPort(options.BindAddress.String(), strconv.Itoa(options.SecurePort)), | 				BindAddress: net.JoinHostPort(options.BindAddress.String(), strconv.Itoa(options.SecurePort)), | ||||||
| 			ServerCert: CertInfo{ | 			}, | ||||||
|  | 			ServerCert: GeneratableKeyCert{ | ||||||
|  | 				CertKey: CertKey{ | ||||||
| 					CertFile: options.TLSCertFile, | 					CertFile: options.TLSCertFile, | ||||||
| 					KeyFile:  options.TLSPrivateKeyFile, | 					KeyFile:  options.TLSPrivateKeyFile, | ||||||
| 				}, | 				}, | ||||||
|  | 			}, | ||||||
|  | 			SNICerts: []NamedCertKey{}, | ||||||
| 			ClientCA: options.ClientCAFile, | 			ClientCA: options.ClientCAFile, | ||||||
| 		} | 		} | ||||||
| 		if options.TLSCertFile == "" && options.TLSPrivateKeyFile == "" { | 		if options.TLSCertFile == "" && options.TLSPrivateKeyFile == "" { | ||||||
| @@ -262,6 +286,17 @@ func (c *Config) ApplyOptions(options *options.ServerRunOptions) *Config { | |||||||
| 			secureServingInfo.ServerCert.KeyFile = path.Join(options.CertDirectory, "apiserver.key") | 			secureServingInfo.ServerCert.KeyFile = path.Join(options.CertDirectory, "apiserver.key") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		secureServingInfo.SNICerts = nil | ||||||
|  | 		for _, nkc := range options.SNICertKeys { | ||||||
|  | 			secureServingInfo.SNICerts = append(secureServingInfo.SNICerts, NamedCertKey{ | ||||||
|  | 				CertKey: CertKey{ | ||||||
|  | 					KeyFile:  nkc.KeyFile, | ||||||
|  | 					CertFile: nkc.CertFile, | ||||||
|  | 				}, | ||||||
|  | 				Names: nkc.Names, | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		c.SecureServingInfo = secureServingInfo | 		c.SecureServingInfo = secureServingInfo | ||||||
| 		c.ReadWritePort = options.SecurePort | 		c.ReadWritePort = options.SecurePort | ||||||
| 	} | 	} | ||||||
| @@ -434,9 +469,16 @@ func (c completedConfig) MaybeGenerateServingCerts() error { | |||||||
| 		alternateIPs := []net.IP{c.ServiceReadWriteIP} | 		alternateIPs := []net.IP{c.ServiceReadWriteIP} | ||||||
| 		alternateDNS := []string{"kubernetes.default.svc", "kubernetes.default", "kubernetes", "localhost"} | 		alternateDNS := []string{"kubernetes.default.svc", "kubernetes.default", "kubernetes", "localhost"} | ||||||
|  |  | ||||||
| 		if err := certutil.GenerateSelfSignedCert(c.PublicAddress.String(), c.SecureServingInfo.ServerCert.CertFile, c.SecureServingInfo.ServerCert.KeyFile, alternateIPs, alternateDNS); err != nil { | 		if cert, key, err := certutil.GenerateSelfSignedCertKey(c.PublicAddress.String(), alternateIPs, alternateDNS); err != nil { | ||||||
| 			return fmt.Errorf("Unable to generate self signed cert: %v", err) | 			return fmt.Errorf("unable to generate self signed cert: %v", err) | ||||||
| 		} else { | 		} else { | ||||||
|  | 			if err := certutil.WriteCert(c.SecureServingInfo.ServerCert.CertFile, cert); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if err := certutil.WriteKey(c.SecureServingInfo.ServerCert.KeyFile, key); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
| 			glog.Infof("Generated self-signed cert (%s, %s)", c.SecureServingInfo.ServerCert.CertFile, c.SecureServingInfo.ServerCert.KeyFile) | 			glog.Infof("Generated self-signed cert (%s, %s)", c.SecureServingInfo.ServerCert.CertFile, c.SecureServingInfo.ServerCert.KeyFile) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -109,7 +109,7 @@ type GenericAPIServer struct { | |||||||
| 	// The registered APIs | 	// The registered APIs | ||||||
| 	HandlerContainer *genericmux.APIContainer | 	HandlerContainer *genericmux.APIContainer | ||||||
|  |  | ||||||
| 	SecureServingInfo   *ServingInfo | 	SecureServingInfo   *SecureServingInfo | ||||||
| 	InsecureServingInfo *ServingInfo | 	InsecureServingInfo *ServingInfo | ||||||
|  |  | ||||||
| 	// numerical ports, set after listening | 	// numerical ports, set after listening | ||||||
|   | |||||||
| @@ -118,6 +118,7 @@ type ServerRunOptions struct { | |||||||
| 	TLSCAFile              string | 	TLSCAFile              string | ||||||
| 	TLSCertFile            string | 	TLSCertFile            string | ||||||
| 	TLSPrivateKeyFile      string | 	TLSPrivateKeyFile      string | ||||||
|  | 	SNICertKeys            []config.NamedCertKey | ||||||
| 	TokenAuthFile          string | 	TokenAuthFile          string | ||||||
| 	EnableAnyToken         bool | 	EnableAnyToken         bool | ||||||
| 	WatchCacheSizes        []string | 	WatchCacheSizes        []string | ||||||
| @@ -488,13 +489,22 @@ func (s *ServerRunOptions) AddUniversalFlags(fs *pflag.FlagSet) { | |||||||
| 		"Controllers. This must be a valid PEM-encoded CA bundle.") | 		"Controllers. This must be a valid PEM-encoded CA bundle.") | ||||||
|  |  | ||||||
| 	fs.StringVar(&s.TLSCertFile, "tls-cert-file", s.TLSCertFile, ""+ | 	fs.StringVar(&s.TLSCertFile, "tls-cert-file", s.TLSCertFile, ""+ | ||||||
| 		"File containing x509 Certificate for HTTPS. (CA cert, if any, concatenated "+ | 		"File containing the default x509 Certificate for HTTPS. (CA cert, if any, concatenated "+ | ||||||
| 		"after server cert). If HTTPS serving is enabled, and --tls-cert-file and "+ | 		"after server cert). If HTTPS serving is enabled, and --tls-cert-file and "+ | ||||||
| 		"--tls-private-key-file are not provided, a self-signed certificate and key "+ | 		"--tls-private-key-file are not provided, a self-signed certificate and key "+ | ||||||
| 		"are generated for the public address and saved to /var/run/kubernetes.") | 		"are generated for the public address and saved to /var/run/kubernetes.") | ||||||
|  |  | ||||||
| 	fs.StringVar(&s.TLSPrivateKeyFile, "tls-private-key-file", s.TLSPrivateKeyFile, | 	fs.StringVar(&s.TLSPrivateKeyFile, "tls-private-key-file", s.TLSPrivateKeyFile, | ||||||
| 		"File containing x509 private key matching --tls-cert-file.") | 		"File containing the default x509 private key matching --tls-cert-file.") | ||||||
|  |  | ||||||
|  | 	fs.Var(config.NewNamedCertKeyArray(&s.SNICertKeys), "tls-sni-cert-key", ""+ | ||||||
|  | 		"A pair of x509 certificate and private key file paths, optionally suffixed with a list of "+ | ||||||
|  | 		"domain patterns which are fully qualified domain names, possibly with prefixed wildcard "+ | ||||||
|  | 		"segments. If no domain patterns are provided, the names of the certificate are "+ | ||||||
|  | 		"extracted. Non-wildcard matches trump over wildcard matches, explicit domain patterns "+ | ||||||
|  | 		"trump over extracted names. For multiple key/certificate pairs, use the "+ | ||||||
|  | 		"--tls-sni-cert-key multiple times. "+ | ||||||
|  | 		"Examples: \"example.key,example.crt\" or \"*.foo.com,foo.com:foo.key,foo.crt\".") | ||||||
|  |  | ||||||
| 	fs.StringVar(&s.TokenAuthFile, "token-auth-file", s.TokenAuthFile, ""+ | 	fs.StringVar(&s.TokenAuthFile, "token-auth-file", s.TokenAuthFile, ""+ | ||||||
| 		"If set, the file that will be used to secure the secure port of the API server "+ | 		"If set, the file that will be used to secure the secure port of the API server "+ | ||||||
|   | |||||||
| @@ -40,11 +40,17 @@ const ( | |||||||
| // be loaded or the initial listen call fails. The actual server loop (stoppable by closing | // be loaded or the initial listen call fails. The actual server loop (stoppable by closing | ||||||
| // stopCh) runs in a go routine, i.e. serveSecurely does not block. | // stopCh) runs in a go routine, i.e. serveSecurely does not block. | ||||||
| func (s *GenericAPIServer) serveSecurely(stopCh <-chan struct{}) error { | func (s *GenericAPIServer) serveSecurely(stopCh <-chan struct{}) error { | ||||||
|  | 	namedCerts, err := getNamedCertificateMap(s.SecureServingInfo.SNICerts) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("unable to load SNI certificates: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	secureServer := &http.Server{ | 	secureServer := &http.Server{ | ||||||
| 		Addr:           s.SecureServingInfo.BindAddress, | 		Addr:           s.SecureServingInfo.BindAddress, | ||||||
| 		Handler:        s.Handler, | 		Handler:        s.Handler, | ||||||
| 		MaxHeaderBytes: 1 << 20, | 		MaxHeaderBytes: 1 << 20, | ||||||
| 		TLSConfig: &tls.Config{ | 		TLSConfig: &tls.Config{ | ||||||
|  | 			NameToCertificate: namedCerts, | ||||||
| 			// Can't use SSLv3 because of POODLE and BEAST | 			// Can't use SSLv3 because of POODLE and BEAST | ||||||
| 			// Can't use TLSv1.0 because of POODLE and BEAST using CBC cipher | 			// Can't use TLSv1.0 because of POODLE and BEAST using CBC cipher | ||||||
| 			// Can't use TLSv1.1 because of RC4 cipher usage | 			// Can't use TLSv1.1 because of RC4 cipher usage | ||||||
| @@ -54,7 +60,6 @@ func (s *GenericAPIServer) serveSecurely(stopCh <-chan struct{}) error { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var err error |  | ||||||
| 	if len(s.SecureServingInfo.ServerCert.CertFile) != 0 || len(s.SecureServingInfo.ServerCert.KeyFile) != 0 { | 	if len(s.SecureServingInfo.ServerCert.CertFile) != 0 || len(s.SecureServingInfo.ServerCert.KeyFile) != 0 { | ||||||
| 		secureServer.TLSConfig.Certificates = make([]tls.Certificate, 1) | 		secureServer.TLSConfig.Certificates = make([]tls.Certificate, 1) | ||||||
| 		secureServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(s.SecureServingInfo.ServerCert.CertFile, s.SecureServingInfo.ServerCert.KeyFile) | 		secureServer.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(s.SecureServingInfo.ServerCert.CertFile, s.SecureServingInfo.ServerCert.KeyFile) | ||||||
| @@ -63,6 +68,14 @@ func (s *GenericAPIServer) serveSecurely(stopCh <-chan struct{}) error { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// append all named certs. Otherwise, the go tls stack will think no SNI processing | ||||||
|  | 	// is necessary because there is only one cert anyway. | ||||||
|  | 	// Moreover, if ServerCert.CertFile/ServerCert.KeyFile are not set, the first SNI | ||||||
|  | 	// cert will become the default cert. That's what we expect anyway. | ||||||
|  | 	for _, c := range namedCerts { | ||||||
|  | 		secureServer.TLSConfig.Certificates = append(secureServer.TLSConfig.Certificates, *c) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if len(s.SecureServingInfo.ClientCA) > 0 { | 	if len(s.SecureServingInfo.ClientCA) > 0 { | ||||||
| 		clientCAs, err := certutil.NewPool(s.SecureServingInfo.ClientCA) | 		clientCAs, err := certutil.NewPool(s.SecureServingInfo.ClientCA) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|   | |||||||
							
								
								
									
										506
									
								
								pkg/genericapiserver/serve_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										506
									
								
								pkg/genericapiserver/serve_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,506 @@ | |||||||
|  | /* | ||||||
|  | Copyright 2016 The Kubernetes Authors. | ||||||
|  |  | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  |  | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | package genericapiserver | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"crypto/x509" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net" | ||||||
|  | 	"os" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	utilcert "k8s.io/kubernetes/pkg/util/cert" | ||||||
|  |  | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type TestCertSpec struct { | ||||||
|  | 	host       string | ||||||
|  | 	names, ips []string // in certificate | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type NamedTestCertSpec struct { | ||||||
|  | 	TestCertSpec | ||||||
|  | 	explicitNames []string // as --tls-sni-cert-key explicit names | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func createTestCerts(spec TestCertSpec) (certFilePath, keyFilePath string, err error) { | ||||||
|  | 	var ips []net.IP | ||||||
|  | 	for _, ip := range spec.ips { | ||||||
|  | 		ips = append(ips, net.ParseIP(ip)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	certPem, keyPem, err := utilcert.GenerateSelfSignedCertKey(spec.host, ips, spec.names) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	certFile, err := ioutil.TempFile(os.TempDir(), "cert") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	keyFile, err := ioutil.TempFile(os.TempDir(), "key") | ||||||
|  | 	if err != nil { | ||||||
|  | 		os.Remove(certFile.Name()) | ||||||
|  | 		return "", "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = certFile.Write(certPem) | ||||||
|  | 	if err != nil { | ||||||
|  | 		os.Remove(certFile.Name()) | ||||||
|  | 		os.Remove(keyFile.Name()) | ||||||
|  | 		return "", "", err | ||||||
|  | 	} | ||||||
|  | 	certFile.Close() | ||||||
|  |  | ||||||
|  | 	_, err = keyFile.Write(keyPem) | ||||||
|  | 	if err != nil { | ||||||
|  | 		os.Remove(certFile.Name()) | ||||||
|  | 		os.Remove(keyFile.Name()) | ||||||
|  | 		return "", "", err | ||||||
|  | 	} | ||||||
|  | 	keyFile.Close() | ||||||
|  |  | ||||||
|  | 	return certFile.Name(), keyFile.Name(), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestGetNamedCertificateMap(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		certs         []NamedTestCertSpec | ||||||
|  | 		explicitNames []string | ||||||
|  | 		expected      map[string]int // name to certs[*] index | ||||||
|  | 		errorString   string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			// empty certs | ||||||
|  | 			expected: map[string]int{}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// only one cert | ||||||
|  | 			certs: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "test.com", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]int{ | ||||||
|  | 				"test.com": 0, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// ips are ignored | ||||||
|  | 			certs: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "test.com", | ||||||
|  | 						ips:  []string{"1.2.3.4"}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]int{ | ||||||
|  | 				"test.com": 0, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// two certs with the same name | ||||||
|  | 			certs: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "test.com", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "test.com", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]int{ | ||||||
|  | 				"test.com": 0, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// two certs with different names | ||||||
|  | 			certs: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "test2.com", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "test1.com", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]int{ | ||||||
|  | 				"test1.com": 1, | ||||||
|  | 				"test2.com": 0, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// two certs with the same name, explicit trumps | ||||||
|  | 			certs: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "test.com", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "test.com", | ||||||
|  | 					}, | ||||||
|  | 					explicitNames: []string{"test.com"}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]int{ | ||||||
|  | 				"test.com": 1, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// certs with partial overlap; ips are ignored | ||||||
|  | 			certs: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host:  "a", | ||||||
|  | 						names: []string{"a.test.com", "test.com"}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host:  "b", | ||||||
|  | 						names: []string{"b.test.com", "test.com"}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expected: map[string]int{ | ||||||
|  | 				"a": 0, "b": 1, | ||||||
|  | 				"a.test.com": 0, "b.test.com": 1, | ||||||
|  | 				"test.com": 0, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// wildcards | ||||||
|  | 			certs: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host:  "a", | ||||||
|  | 						names: []string{"a.test.com", "test.com"}, | ||||||
|  | 					}, | ||||||
|  | 					explicitNames: []string{"*.test.com", "test.com"}, | ||||||
|  | 				}, | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host:  "b", | ||||||
|  | 						names: []string{"b.test.com", "test.com"}, | ||||||
|  | 					}, | ||||||
|  | 					explicitNames: []string{"dev.test.com", "test.com"}, | ||||||
|  | 				}}, | ||||||
|  | 			expected: map[string]int{ | ||||||
|  | 				"test.com":     0, | ||||||
|  | 				"*.test.com":   0, | ||||||
|  | 				"dev.test.com": 1, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | NextTest: | ||||||
|  | 	for i, test := range tests { | ||||||
|  | 		var namedCertKeys []NamedCertKey | ||||||
|  | 		bySignature := map[string]int{} // index in test.certs by cert signature | ||||||
|  | 		for j, c := range test.certs { | ||||||
|  | 			certFile, keyFile, err := createTestCerts(c.TestCertSpec) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Errorf("%d - failed to create cert %d: %v", i, j, err) | ||||||
|  | 				continue NextTest | ||||||
|  | 			} | ||||||
|  | 			defer os.Remove(certFile) | ||||||
|  | 			defer os.Remove(keyFile) | ||||||
|  |  | ||||||
|  | 			namedCertKeys = append(namedCertKeys, NamedCertKey{ | ||||||
|  | 				CertKey: CertKey{ | ||||||
|  | 					KeyFile:  keyFile, | ||||||
|  | 					CertFile: certFile, | ||||||
|  | 				}, | ||||||
|  | 				Names: c.explicitNames, | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			sig, err := certFileSignature(certFile, keyFile) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Errorf("%d - failed to get signature for %d: %v", i, j, err) | ||||||
|  | 				continue NextTest | ||||||
|  | 			} | ||||||
|  | 			bySignature[sig] = j | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		certMap, err := getNamedCertificateMap(namedCertKeys) | ||||||
|  | 		if err == nil && len(test.errorString) != 0 { | ||||||
|  | 			t.Errorf("%d - expected no error, got: %v", i, err) | ||||||
|  | 		} else if err != nil && err.Error() != test.errorString { | ||||||
|  | 			t.Errorf("%d - expected error %q, got: %v", i, test.errorString, err) | ||||||
|  | 		} else { | ||||||
|  | 			got := map[string]int{} | ||||||
|  | 			for name, cert := range certMap { | ||||||
|  | 				x509Certs, err := x509.ParseCertificates(cert.Certificate[0]) | ||||||
|  | 				assert.NoError(t, err, "%d - invalid certificate for %q", i, name) | ||||||
|  | 				assert.True(t, len(x509Certs) > 0, "%d - expected at least one x509 cert in tls cert for %q", i, name) | ||||||
|  | 				got[name] = bySignature[x509CertSignature(x509Certs[0])] | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			assert.EqualValues(t, test.expected, got, "%d - wrong certificate map", i) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestServerRunWithSNI(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		Cert              TestCertSpec | ||||||
|  | 		SNICerts          []NamedTestCertSpec | ||||||
|  | 		ExpectedCertIndex int | ||||||
|  |  | ||||||
|  | 		// passed in the client hello info, "localhost" if unset | ||||||
|  | 		ServerName string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			// only one cert | ||||||
|  | 			Cert: TestCertSpec{ | ||||||
|  | 				host: "localhost", | ||||||
|  | 			}, | ||||||
|  | 			ExpectedCertIndex: -1, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// cert with multiple alternate names | ||||||
|  | 			Cert: TestCertSpec{ | ||||||
|  | 				host:  "localhost", | ||||||
|  | 				names: []string{"test.com"}, | ||||||
|  | 				ips:   []string{"127.0.0.1"}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedCertIndex: -1, | ||||||
|  | 			ServerName:        "test.com", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// one SNI and the default cert with the same name | ||||||
|  | 			Cert: TestCertSpec{ | ||||||
|  | 				host: "localhost", | ||||||
|  | 			}, | ||||||
|  | 			SNICerts: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "localhost", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedCertIndex: 0, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// matching SNI cert | ||||||
|  | 			Cert: TestCertSpec{ | ||||||
|  | 				host: "localhost", | ||||||
|  | 			}, | ||||||
|  | 			SNICerts: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "test.com", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedCertIndex: 0, | ||||||
|  | 			ServerName:        "test.com", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// matching IP in SNI cert and the server cert. But IPs must not be | ||||||
|  | 			// passed via SNI. Hence, the ServerName in the HELLO packet is empty | ||||||
|  | 			// and the server should select the non-SNI cert. | ||||||
|  | 			Cert: TestCertSpec{ | ||||||
|  | 				host: "localhost", | ||||||
|  | 				ips:  []string{"10.0.0.1"}, | ||||||
|  | 			}, | ||||||
|  | 			SNICerts: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host: "test.com", | ||||||
|  | 						ips:  []string{"10.0.0.1"}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedCertIndex: -1, | ||||||
|  | 			ServerName:        "10.0.0.1", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			// wildcards | ||||||
|  | 			Cert: TestCertSpec{ | ||||||
|  | 				host: "localhost", | ||||||
|  | 			}, | ||||||
|  | 			SNICerts: []NamedTestCertSpec{ | ||||||
|  | 				{ | ||||||
|  | 					TestCertSpec: TestCertSpec{ | ||||||
|  | 						host:  "test.com", | ||||||
|  | 						names: []string{"*.test.com"}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			ExpectedCertIndex: 0, | ||||||
|  | 			ServerName:        "www.test.com", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | NextTest: | ||||||
|  | 	for i, test := range tests { | ||||||
|  | 		// create server cert | ||||||
|  | 		serverCertFile, serverKeyFile, err := createTestCerts(test.Cert) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("%d - failed to create server cert: %v", i, err) | ||||||
|  | 		} | ||||||
|  | 		defer os.Remove(serverCertFile) | ||||||
|  | 		defer os.Remove(serverKeyFile) | ||||||
|  |  | ||||||
|  | 		// create SNI certs | ||||||
|  | 		var namedCertKeys []NamedCertKey | ||||||
|  | 		serverSig, err := certFileSignature(serverCertFile, serverKeyFile) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("%d - failed to get server cert signature: %v", i, err) | ||||||
|  | 			continue NextTest | ||||||
|  | 		} | ||||||
|  | 		signatures := map[string]int{ | ||||||
|  | 			serverSig: -1, | ||||||
|  | 		} | ||||||
|  | 		for j, c := range test.SNICerts { | ||||||
|  | 			certFile, keyFile, err := createTestCerts(c.TestCertSpec) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Errorf("%d - failed to create SNI cert %d: %v", i, j, err) | ||||||
|  | 				continue NextTest | ||||||
|  | 			} | ||||||
|  | 			defer os.Remove(certFile) | ||||||
|  | 			defer os.Remove(keyFile) | ||||||
|  |  | ||||||
|  | 			namedCertKeys = append(namedCertKeys, NamedCertKey{ | ||||||
|  | 				CertKey: CertKey{ | ||||||
|  | 					KeyFile:  keyFile, | ||||||
|  | 					CertFile: certFile, | ||||||
|  | 				}, | ||||||
|  | 				Names: c.explicitNames, | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			// store index in namedCertKeys with the signature as the key | ||||||
|  | 			sig, err := certFileSignature(certFile, keyFile) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Errorf("%d - failed get SNI cert %d signature: %v", i, j, err) | ||||||
|  | 				continue NextTest | ||||||
|  | 			} | ||||||
|  | 			signatures[sig] = j | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		stopCh := make(chan struct{}) | ||||||
|  |  | ||||||
|  | 		// launch server | ||||||
|  | 		etcdserver, config, _ := setUp(t) | ||||||
|  | 		defer etcdserver.Terminate(t) | ||||||
|  |  | ||||||
|  | 		config.EnableIndex = true | ||||||
|  | 		config.SecureServingInfo = &SecureServingInfo{ | ||||||
|  | 			ServingInfo: ServingInfo{ | ||||||
|  | 				BindAddress: "localhost:0", | ||||||
|  | 			}, | ||||||
|  | 			ServerCert: GeneratableKeyCert{ | ||||||
|  | 				CertKey: CertKey{ | ||||||
|  | 					CertFile: serverCertFile, | ||||||
|  | 					KeyFile:  serverKeyFile, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			SNICerts: namedCertKeys, | ||||||
|  | 		} | ||||||
|  | 		config.InsecureServingInfo = nil | ||||||
|  |  | ||||||
|  | 		s, err := config.Complete().New() | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("%d - failed creating the server: %v", i, err) | ||||||
|  | 			continue NextTest | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if err := s.serveSecurely(stopCh); err != nil { | ||||||
|  | 			t.Errorf("%d - failed running the server: %v", i, err) | ||||||
|  | 			continue NextTest | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// load certificates into a pool | ||||||
|  | 		roots := x509.NewCertPool() | ||||||
|  | 		certFiles := []string{serverCertFile} | ||||||
|  | 		for _, c := range namedCertKeys { | ||||||
|  | 			certFiles = append(certFiles, c.CertFile) | ||||||
|  | 		} | ||||||
|  | 		for _, certFile := range certFiles { | ||||||
|  | 			bs, err := ioutil.ReadFile(certFile) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Errorf("%d - error reading %q: %v", i, certFile, err) | ||||||
|  | 				continue NextTest | ||||||
|  | 			} | ||||||
|  | 			if ok := roots.AppendCertsFromPEM(bs); !ok { | ||||||
|  | 				t.Errorf("%d - error adding cert %q to the pool", i, certFile) | ||||||
|  | 				continue NextTest | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// try to dial | ||||||
|  | 		addr := fmt.Sprintf("localhost:%d", s.effectiveSecurePort) | ||||||
|  | 		t.Logf("Dialing %s as %q", addr, test.ServerName) | ||||||
|  | 		conn, err := tls.Dial("tcp", addr, &tls.Config{ | ||||||
|  | 			RootCAs:    roots, | ||||||
|  | 			ServerName: test.ServerName, // used for SNI in the client HELLO packet | ||||||
|  | 		}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("%d - failed to connect: %v", i, err) | ||||||
|  | 			continue NextTest | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// check returned server certificate | ||||||
|  | 		sig := x509CertSignature(conn.ConnectionState().PeerCertificates[0]) | ||||||
|  | 		gotCertIndex, found := signatures[sig] | ||||||
|  | 		if !found { | ||||||
|  | 			t.Errorf("%d - unknown signature returned from server: %s", i, sig) | ||||||
|  | 		} | ||||||
|  | 		if gotCertIndex != test.ExpectedCertIndex { | ||||||
|  | 			t.Errorf("%d - expected cert index %d, got cert index %d", i, test.ExpectedCertIndex, gotCertIndex) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		conn.Close() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func x509CertSignature(cert *x509.Certificate) string { | ||||||
|  | 	return base64.StdEncoding.EncodeToString(cert.Signature) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func certFileSignature(certFile, keyFile string) (string, error) { | ||||||
|  | 	cert, err := tls.LoadX509KeyPair(certFile, keyFile) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	x509Certs, err := x509.ParseCertificates(cert.Certificate[0]) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	if len(x509Certs) == 0 { | ||||||
|  | 		return "", fmt.Errorf("expected at least one cert after reparsing cert %q", certFile) | ||||||
|  | 	} | ||||||
|  | 	return x509CertSignature(x509Certs[0]), nil | ||||||
|  | } | ||||||
| @@ -126,22 +126,19 @@ func MakeEllipticPrivateKeyPEM() ([]byte, error) { | |||||||
| 	return pem.EncodeToMemory(privateKeyPemBlock), nil | 	return pem.EncodeToMemory(privateKeyPemBlock), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // GenerateSelfSignedCert creates a self-signed certificate and key for the given host. | // GenerateSelfSignedCertKey creates a self-signed certificate and key for the given host. | ||||||
| // Host may be an IP or a DNS name | // Host may be an IP or a DNS name | ||||||
| // You may also specify additional subject alt names (either ip or dns names) for the certificate | // You may also specify additional subject alt names (either ip or dns names) for the certificate | ||||||
| // The certificate will be created with file mode 0644. The key will be created with file mode 0600. | func GenerateSelfSignedCertKey(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error) { | ||||||
| // If the certificate or key files already exist, they will be overwritten. |  | ||||||
| // Any parent directories of the certPath or keyPath will be created as needed with file mode 0755. |  | ||||||
| func GenerateSelfSignedCert(host, certPath, keyPath string, alternateIPs []net.IP, alternateDNS []string) error { |  | ||||||
| 	priv, err := rsa.GenerateKey(cryptorand.Reader, 2048) | 	priv, err := rsa.GenerateKey(cryptorand.Reader, 2048) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	template := x509.Certificate{ | 	template := x509.Certificate{ | ||||||
| 		SerialNumber: big.NewInt(1), | 		SerialNumber: big.NewInt(1), | ||||||
| 		Subject: pkix.Name{ | 		Subject: pkix.Name{ | ||||||
| 			CommonName: fmt.Sprintf("%s@%d", host, time.Now().Unix()), | 			CommonName: host, | ||||||
| 		}, | 		}, | ||||||
| 		NotBefore: time.Now(), | 		NotBefore: time.Now(), | ||||||
| 		NotAfter:  time.Now().Add(time.Hour * 24 * 365), | 		NotAfter:  time.Now().Add(time.Hour * 24 * 365), | ||||||
| @@ -163,30 +160,22 @@ func GenerateSelfSignedCert(host, certPath, keyPath string, alternateIPs []net.I | |||||||
|  |  | ||||||
| 	derBytes, err := x509.CreateCertificate(cryptorand.Reader, &template, &template, &priv.PublicKey, priv) | 	derBytes, err := x509.CreateCertificate(cryptorand.Reader, &template, &template, &priv.PublicKey, priv) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Generate cert | 	// Generate cert | ||||||
| 	certBuffer := bytes.Buffer{} | 	certBuffer := bytes.Buffer{} | ||||||
| 	if err := pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { | 	if err := pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { | ||||||
| 		return err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Generate key | 	// Generate key | ||||||
| 	keyBuffer := bytes.Buffer{} | 	keyBuffer := bytes.Buffer{} | ||||||
| 	if err := pem.Encode(&keyBuffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil { | 	if err := pem.Encode(&keyBuffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil { | ||||||
| 		return err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err := WriteCert(certPath, certBuffer.Bytes()); err != nil { | 	return certBuffer.Bytes(), keyBuffer.Bytes(), nil | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if err := WriteKey(keyPath, keyBuffer.Bytes()); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // FormatBytesCert receives byte array certificate and formats in human-readable format | // FormatBytesCert receives byte array certificate and formats in human-readable format | ||||||
|   | |||||||
							
								
								
									
										113
									
								
								pkg/util/config/namedcertkey_flag.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								pkg/util/config/namedcertkey_flag.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | |||||||
|  | /* | ||||||
|  | Copyright 2016 The Kubernetes Authors. | ||||||
|  |  | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  |  | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | package config | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"flag" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // NamedCertKey is a flag value parsing "certfile,keyfile" and "certfile,keyfile:name,name,name". | ||||||
|  | type NamedCertKey struct { | ||||||
|  | 	Names             []string | ||||||
|  | 	CertFile, KeyFile string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var _ flag.Value = &NamedCertKey{} | ||||||
|  |  | ||||||
|  | func (nkc *NamedCertKey) String() string { | ||||||
|  | 	s := nkc.CertFile + "," + nkc.KeyFile | ||||||
|  | 	if len(nkc.Names) > 0 { | ||||||
|  | 		s = s + ":" + strings.Join(nkc.Names, ",") | ||||||
|  | 	} | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (nkc *NamedCertKey) Set(value string) error { | ||||||
|  | 	cs := strings.SplitN(value, ":", 2) | ||||||
|  | 	var keycert string | ||||||
|  | 	if len(cs) == 2 { | ||||||
|  | 		var names string | ||||||
|  | 		keycert, names = strings.TrimSpace(cs[0]), strings.TrimSpace(cs[1]) | ||||||
|  | 		if names == "" { | ||||||
|  | 			return errors.New("empty names list is not allowed") | ||||||
|  | 		} | ||||||
|  | 		nkc.Names = nil | ||||||
|  | 		for _, name := range strings.Split(names, ",") { | ||||||
|  | 			nkc.Names = append(nkc.Names, strings.TrimSpace(name)) | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		nkc.Names = nil | ||||||
|  | 		keycert = strings.TrimSpace(cs[0]) | ||||||
|  | 	} | ||||||
|  | 	cs = strings.Split(keycert, ",") | ||||||
|  | 	if len(cs) != 2 { | ||||||
|  | 		return errors.New("expected comma separated certificate and key file paths") | ||||||
|  | 	} | ||||||
|  | 	nkc.CertFile = strings.TrimSpace(cs[0]) | ||||||
|  | 	nkc.KeyFile = strings.TrimSpace(cs[1]) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (*NamedCertKey) Type() string { | ||||||
|  | 	return "namedCertKey" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NamedCertKeyArray is a flag value parsing NamedCertKeys, each passed with its own | ||||||
|  | // flag instance (in contrast to comma separated slices). | ||||||
|  | type NamedCertKeyArray struct { | ||||||
|  | 	value   *[]NamedCertKey | ||||||
|  | 	changed bool | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var _ flag.Value = &NamedCertKey{} | ||||||
|  |  | ||||||
|  | // NewNamedKeyCertArray creates a new NamedCertKeyArray with the internal value | ||||||
|  | // pointing to p. | ||||||
|  | func NewNamedCertKeyArray(p *[]NamedCertKey) *NamedCertKeyArray { | ||||||
|  | 	return &NamedCertKeyArray{ | ||||||
|  | 		value: p, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *NamedCertKeyArray) Set(val string) error { | ||||||
|  | 	nkc := NamedCertKey{} | ||||||
|  | 	err := nkc.Set(val) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if !a.changed { | ||||||
|  | 		*a.value = []NamedCertKey{nkc} | ||||||
|  | 		a.changed = true | ||||||
|  | 	} else { | ||||||
|  | 		*a.value = append(*a.value, nkc) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *NamedCertKeyArray) Type() string { | ||||||
|  | 	return "namedCertKey" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *NamedCertKeyArray) String() string { | ||||||
|  | 	nkcs := make([]string, 0, len(*a.value)) | ||||||
|  | 	for i := range *a.value { | ||||||
|  | 		nkcs = append(nkcs, (*a.value)[i].String()) | ||||||
|  | 	} | ||||||
|  | 	return "[" + strings.Join(nkcs, ";") + "]" | ||||||
|  | } | ||||||
							
								
								
									
										138
									
								
								pkg/util/config/namedcertkey_flag_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								pkg/util/config/namedcertkey_flag_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,138 @@ | |||||||
|  | /* | ||||||
|  | Copyright 2016 The Kubernetes Authors. | ||||||
|  |  | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  |  | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | package config | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/spf13/pflag" | ||||||
|  | 	"reflect" | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestNamedCertKeyArrayFlag(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		args       []string | ||||||
|  | 		def        []NamedCertKey | ||||||
|  | 		expected   []NamedCertKey | ||||||
|  | 		parseError string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			args:     []string{}, | ||||||
|  | 			expected: nil, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args: []string{"foo.crt,foo.key"}, | ||||||
|  | 			expected: []NamedCertKey{{ | ||||||
|  | 				KeyFile:  "foo.key", | ||||||
|  | 				CertFile: "foo.crt", | ||||||
|  | 			}}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args: []string{"  foo.crt , foo.key    "}, | ||||||
|  | 			expected: []NamedCertKey{{ | ||||||
|  | 				KeyFile:  "foo.key", | ||||||
|  | 				CertFile: "foo.crt", | ||||||
|  | 			}}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args: []string{"foo.crt,foo.key:abc"}, | ||||||
|  | 			expected: []NamedCertKey{{ | ||||||
|  | 				KeyFile:  "foo.key", | ||||||
|  | 				CertFile: "foo.crt", | ||||||
|  | 				Names:    []string{"abc"}, | ||||||
|  | 			}}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args: []string{"foo.crt,foo.key: abc  "}, | ||||||
|  | 			expected: []NamedCertKey{{ | ||||||
|  | 				KeyFile:  "foo.key", | ||||||
|  | 				CertFile: "foo.crt", | ||||||
|  | 				Names:    []string{"abc"}, | ||||||
|  | 			}}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args:       []string{"foo.crt,foo.key:"}, | ||||||
|  | 			parseError: "empty names list is not allowed", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args:       []string{""}, | ||||||
|  | 			parseError: "expected comma separated certificate and key file paths", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args:       []string{"   "}, | ||||||
|  | 			parseError: "expected comma separated certificate and key file paths", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args:       []string{"a,b,c"}, | ||||||
|  | 			parseError: "expected comma separated certificate and key file paths", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args: []string{"foo.crt,foo.key:abc,def,ghi"}, | ||||||
|  | 			expected: []NamedCertKey{{ | ||||||
|  | 				KeyFile:  "foo.key", | ||||||
|  | 				CertFile: "foo.crt", | ||||||
|  | 				Names:    []string{"abc", "def", "ghi"}, | ||||||
|  | 			}}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args: []string{"foo.crt,foo.key:*.*.*"}, | ||||||
|  | 			expected: []NamedCertKey{{ | ||||||
|  | 				KeyFile:  "foo.key", | ||||||
|  | 				CertFile: "foo.crt", | ||||||
|  | 				Names:    []string{"*.*.*"}, | ||||||
|  | 			}}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			args: []string{"foo.crt,foo.key", "bar.crt,bar.key"}, | ||||||
|  | 			expected: []NamedCertKey{{ | ||||||
|  | 				KeyFile:  "foo.key", | ||||||
|  | 				CertFile: "foo.crt", | ||||||
|  | 			}, { | ||||||
|  | 				KeyFile:  "bar.key", | ||||||
|  | 				CertFile: "bar.crt", | ||||||
|  | 			}}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for i, test := range tests { | ||||||
|  | 		fs := pflag.NewFlagSet("testNamedCertKeyArray", pflag.ContinueOnError) | ||||||
|  | 		var nkcs []NamedCertKey | ||||||
|  | 		for _, d := range test.def { | ||||||
|  | 			nkcs = append(nkcs, d) | ||||||
|  | 		} | ||||||
|  | 		fs.Var(NewNamedCertKeyArray(&nkcs), "tls-sni-cert-key", "usage") | ||||||
|  |  | ||||||
|  | 		args := []string{} | ||||||
|  | 		for _, a := range test.args { | ||||||
|  | 			args = append(args, fmt.Sprintf("--tls-sni-cert-key=%s", a)) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		err := fs.Parse(args) | ||||||
|  | 		if test.parseError != "" { | ||||||
|  | 			if err == nil { | ||||||
|  | 				t.Errorf("%d: expected error %q, got nil", i, test.parseError) | ||||||
|  | 			} else if !strings.Contains(err.Error(), test.parseError) { | ||||||
|  | 				t.Errorf("%d: expected error %q, got %q", i, test.parseError, err) | ||||||
|  | 			} | ||||||
|  | 		} else if err != nil { | ||||||
|  | 			t.Errorf("%d: expected nil error, got %v", i, err) | ||||||
|  | 		} | ||||||
|  | 		if !reflect.DeepEqual(nkcs, test.expected) { | ||||||
|  | 			t.Errorf("%d: expected %+v, got %+v", i, test.expected, nkcs) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user