diff --git a/staging/src/k8s.io/client-go/util/cert/BUILD b/staging/src/k8s.io/client-go/util/cert/BUILD index 661c5e35104..514a2572fc2 100644 --- a/staging/src/k8s.io/client-go/util/cert/BUILD +++ b/staging/src/k8s.io/client-go/util/cert/BUILD @@ -21,6 +21,7 @@ go_library( "csr.go", "io.go", "pem.go", + "server_inspection.go", ], importmap = "k8s.io/kubernetes/vendor/k8s.io/client-go/util/cert", importpath = "k8s.io/client-go/util/cert", diff --git a/staging/src/k8s.io/client-go/util/cert/server_inspection.go b/staging/src/k8s.io/client-go/util/cert/server_inspection.go new file mode 100644 index 00000000000..6d228916d19 --- /dev/null +++ b/staging/src/k8s.io/client-go/util/cert/server_inspection.go @@ -0,0 +1,97 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cert + +import ( + "crypto/tls" + "crypto/x509" + "net/url" + "strings" +) + +// GetClientCANames gets the CA names for client certs that a server accepts. This is useful when inspecting the +// state of particular servers. apiHost is "host:port" +func GetClientCANames(apiHost string) ([]string, error) { + // when we run this the second time, we know which one we are expecting + acceptableCAs := []string{} + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, // this is insecure to always get to the GetClientCertificate + GetClientCertificate: func(hello *tls.CertificateRequestInfo) (*tls.Certificate, error) { + acceptableCAs = []string{} + for _, curr := range hello.AcceptableCAs { + acceptableCAs = append(acceptableCAs, string(curr)) + } + return &tls.Certificate{}, nil + }, + } + + conn, err := tls.Dial("tcp", apiHost, tlsConfig) + if err != nil { + return nil, err + } + defer conn.Close() + + return acceptableCAs, nil +} + +// GetClientCANamesForURL is GetClientCANames against a URL string like we use in kubeconfigs +func GetClientCANamesForURL(kubeConfigURL string) ([]string, error) { + apiserverURL, err := url.Parse(kubeConfigURL) + if err != nil { + return nil, err + } + return GetClientCANames(apiserverURL.Host) +} + +// GetServingCertificates returns the x509 certs used by a server. The serverName is optional for specifying a different +// name to get SNI certificates. apiHost is "host:port" +func GetServingCertificates(apiHost, serverName string) ([]*x509.Certificate, [][]byte, error) { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, // this is insecure so that we always get connected + } + // if a name is specified for SNI, set it. + if len(serverName) > 0 { + tlsConfig.ServerName = serverName + } + + conn, err := tls.Dial("tcp", apiHost, tlsConfig) + if err != nil { + return nil, nil, err + } + conn.Close() + + peerCerts := conn.ConnectionState().PeerCertificates + peerCertBytes := [][]byte{} + for _, a := range peerCerts { + actualCert, err := EncodeCertificates(a) + if err != nil { + return nil, nil, err + } + peerCertBytes = append(peerCertBytes, []byte(strings.TrimSpace(string(actualCert)))) + } + + return peerCerts, peerCertBytes, err +} + +// GetServingCertificatesForURL is GetServingCertificates against a URL string like we use in kubeconfigs +func GetServingCertificatesForURL(kubeConfigURL, serverName string) ([]*x509.Certificate, [][]byte, error) { + apiserverURL, err := url.Parse(kubeConfigURL) + if err != nil { + return nil, nil, err + } + return GetServingCertificates(apiserverURL.Host, serverName) +} diff --git a/test/integration/apiserver/certreload/BUILD b/test/integration/apiserver/certreload/BUILD index 5eb80156cbf..b1e7faa529d 100644 --- a/test/integration/apiserver/certreload/BUILD +++ b/test/integration/apiserver/certreload/BUILD @@ -14,6 +14,7 @@ go_test( "//staging/src/k8s.io/apimachinery/pkg/util/wait:go_default_library", "//staging/src/k8s.io/apiserver/pkg/server/dynamiccertificates:go_default_library", "//staging/src/k8s.io/client-go/kubernetes:go_default_library", + "//staging/src/k8s.io/client-go/util/cert:go_default_library", "//staging/src/k8s.io/component-base/cli/flag:go_default_library", "//test/integration/framework:go_default_library", ], diff --git a/test/integration/apiserver/certreload/certreload_test.go b/test/integration/apiserver/certreload/certreload_test.go index 306c827993d..a23ae31703c 100644 --- a/test/integration/apiserver/certreload/certreload_test.go +++ b/test/integration/apiserver/certreload/certreload_test.go @@ -18,11 +18,8 @@ package podlogs import ( "bytes" - "crypto/tls" - "crypto/x509" - "encoding/base64" + "fmt" "io/ioutil" - "net/url" "path" "strings" "testing" @@ -33,6 +30,7 @@ import ( "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apiserver/pkg/server/dynamiccertificates" "k8s.io/client-go/kubernetes" + "k8s.io/client-go/util/cert" "k8s.io/component-base/cli/flag" "k8s.io/kubernetes/cmd/kube-apiserver/app/options" "k8s.io/kubernetes/test/integration/framework" @@ -88,13 +86,9 @@ MnVCuBwfwDXCAiEAw/1TA+CjPq9JC5ek1ifR0FybTURjeQqYkKpve1dveps= dynamiccertificates.FileRefreshDuration = 1 * time.Second }, }) - apiserverURL, err := url.Parse(kubeconfig.Host) - if err != nil { - t.Fatal(err) - } // wait for request header info - err = wait.PollImmediate(100*time.Millisecond, 30*time.Second, waitForConfigMapCAContent(t, kubeClient, "requestheader-client-ca-file", "-----BEGIN CERTIFICATE-----", 1)) + err := wait.PollImmediate(100*time.Millisecond, 30*time.Second, waitForConfigMapCAContent(t, kubeClient, "requestheader-client-ca-file", "-----BEGIN CERTIFICATE-----", 1)) if err != nil { t.Fatal(err) } @@ -105,24 +99,6 @@ MnVCuBwfwDXCAiEAw/1TA+CjPq9JC5ek1ifR0FybTURjeQqYkKpve1dveps= } // when we run this the second time, we know which one we are expecting - acceptableCAs := []string{} - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - GetClientCertificate: func(hello *tls.CertificateRequestInfo) (*tls.Certificate, error) { - acceptableCAs = []string{} - for _, curr := range hello.AcceptableCAs { - acceptableCAs = append(acceptableCAs, string(curr)) - } - return &tls.Certificate{}, nil - }, - } - - conn, err := tls.Dial("tcp", apiserverURL.Host, tlsConfig) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - if err := ioutil.WriteFile(clientCAFilename, differentClientCA, 0644); err != nil { t.Fatal(err) } @@ -132,11 +108,10 @@ MnVCuBwfwDXCAiEAw/1TA+CjPq9JC5ek1ifR0FybTURjeQqYkKpve1dveps= time.Sleep(4 * time.Second) - conn2, err := tls.Dial("tcp", apiserverURL.Host, tlsConfig) + acceptableCAs, err := cert.GetClientCANamesForURL(kubeconfig.Host) if err != nil { t.Fatal(err) } - defer conn2.Close() expectedCAs := []string{"webhook-test.default.svc", "My Client"} if len(expectedCAs) != len(acceptableCAs) { @@ -334,29 +309,6 @@ func TestServingCert(t *testing.T) { dynamiccertificates.FileRefreshDuration = 1 * time.Second }, }) - apiserverURL, err := url.Parse(kubeconfig.Host) - if err != nil { - t.Fatal(err) - } - - // when we run this the second time, we know which one we are expecting - acceptableCerts := [][]byte{} - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - acceptableCerts = make([][]byte, 0, len(rawCerts)) - for _, r := range rawCerts { - acceptableCerts = append(acceptableCerts, r) - } - return nil - }, - } - - conn, err := tls.Dial("tcp", apiserverURL.Host, tlsConfig) - if err != nil { - t.Fatal(err) - } - defer conn.Close() if err := ioutil.WriteFile(path.Join(servingCertPath, "apiserver.key"), serverKey, 0644); err != nil { t.Fatal(err) @@ -367,30 +319,14 @@ func TestServingCert(t *testing.T) { time.Sleep(4 * time.Second) - conn2, err := tls.Dial("tcp", apiserverURL.Host, tlsConfig) + // get the certs we're actually serving with + _, actualCerts, err := cert.GetServingCertificatesForURL(kubeconfig.Host, "") if err != nil { t.Fatal(err) } - defer conn2.Close() - - cert, err := tls.X509KeyPair(serverCert, serverKey) - if err != nil { + if err := checkServingCerts(serverCert, actualCerts); err != nil { t.Fatal(err) } - - expectedCerts := cert.Certificate - if len(expectedCerts) != len(acceptableCerts) { - var certs []string - for _, a := range acceptableCerts { - certs = append(certs, base64.StdEncoding.EncodeToString(a)) - } - t.Fatalf("Unexpected number of certs: %v", strings.Join(certs, ":")) - } - for i := range expectedCerts { - if !bytes.Equal(acceptableCerts[i], expectedCerts[i]) { - t.Errorf("expected %q, got %q", base64.StdEncoding.EncodeToString(expectedCerts[i]), base64.StdEncoding.EncodeToString(acceptableCerts[i])) - } - } } func TestSNICert(t *testing.T) { @@ -419,50 +355,16 @@ func TestSNICert(t *testing.T) { }} }, }) - apiserverURL, err := url.Parse(kubeconfig.Host) - if err != nil { - t.Fatal(err) - } // When we run this the second time, we know which one we are expecting. - acceptableCerts := [][]byte{} - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - acceptableCerts = make([][]byte, 0, len(rawCerts)) - for _, r := range rawCerts { - acceptableCerts = append(acceptableCerts, r) - } - return nil - }, - ServerName: "foo", - } - - conn, err := tls.Dial("tcp", apiserverURL.Host, tlsConfig) + _, actualCerts, err := cert.GetServingCertificatesForURL(kubeconfig.Host, "foo") if err != nil { t.Fatal(err) } - conn.Close() - - cert, err := tls.LoadX509KeyPair(path.Join(servingCertPath, "foo.crt"), path.Join(servingCertPath, "foo.key")) - if err != nil { + if err := checkServingCerts(anotherServerCert, actualCerts); err != nil { t.Fatal(err) } - expectedCerts := cert.Certificate - if len(expectedCerts) != len(acceptableCerts) { - var certs []string - for _, a := range acceptableCerts { - certs = append(certs, base64.StdEncoding.EncodeToString(a)) - } - t.Fatalf("Unexpected number of certs: %v", strings.Join(certs, ":")) - } - for i := range expectedCerts { - if !bytes.Equal(acceptableCerts[i], expectedCerts[i]) { - t.Errorf("expected %q, got %q", base64.StdEncoding.EncodeToString(expectedCerts[i]), base64.StdEncoding.EncodeToString(acceptableCerts[i])) - } - } - if err := ioutil.WriteFile(path.Join(servingCertPath, "foo.key"), serverKey, 0644); err != nil { t.Fatal(err) } @@ -472,28 +374,40 @@ func TestSNICert(t *testing.T) { time.Sleep(4 * time.Second) - conn2, err := tls.Dial("tcp", apiserverURL.Host, tlsConfig) + _, actualCerts, err = cert.GetServingCertificatesForURL(kubeconfig.Host, "foo") if err != nil { t.Fatal(err) } - conn2.Close() - - cert, err = tls.X509KeyPair(serverCert, serverKey) - if err != nil { + if err := checkServingCerts(serverCert, actualCerts); err != nil { t.Fatal(err) } - - expectedCerts = cert.Certificate - if len(expectedCerts) != len(acceptableCerts) { - var certs []string - for _, a := range acceptableCerts { - certs = append(certs, base64.StdEncoding.EncodeToString(a)) - } - t.Fatalf("Unexpected number of certs: %v", strings.Join(certs, ":")) - } - for i := range expectedCerts { - if !bytes.Equal(acceptableCerts[i], expectedCerts[i]) { - t.Errorf("expected %q, got %q", base64.StdEncoding.EncodeToString(expectedCerts[i]), base64.StdEncoding.EncodeToString(acceptableCerts[i])) - } - } +} + +func checkServingCerts(expectedBytes []byte, actual [][]byte) error { + expectedCerts, err := cert.ParseCertsPEM(expectedBytes) + if err != nil { + return err + } + expected := [][]byte{} + for _, curr := range expectedCerts { + currBytes, err := cert.EncodeCertificates(curr) + if err != nil { + return err + } + expected = append(expected, []byte(strings.TrimSpace(string(currBytes)))) + } + + if len(expected) != len(actual) { + var certs []string + for _, a := range actual { + certs = append(certs, string(a)) + } + return fmt.Errorf("unexpected number of certs %d vs %d: %v", len(expected), len(actual), strings.Join(certs, "\n")) + } + for i := range expected { + if !bytes.Equal(actual[i], expected[i]) { + return fmt.Errorf("expected %q, got %q", string(expected[i]), string(actual[i])) + } + } + return nil }