From 64bc96baf9b711ce8b58f70b7f569c8324382523 Mon Sep 17 00:00:00 2001 From: Maria Ntalla Date: Wed, 6 Jun 2018 16:22:29 +0100 Subject: [PATCH] Setup test for verifying by checking certificate fingerprints --- .../vsphere/vclib/connection_test.go | 69 +++++++++++++++---- 1 file changed, 55 insertions(+), 14 deletions(-) diff --git a/pkg/cloudprovider/providers/vsphere/vclib/connection_test.go b/pkg/cloudprovider/providers/vsphere/vclib/connection_test.go index 1af398990a9..3fc7f52eace 100644 --- a/pkg/cloudprovider/providers/vsphere/vclib/connection_test.go +++ b/pkg/cloudprovider/providers/vsphere/vclib/connection_test.go @@ -18,8 +18,10 @@ package vclib_test import ( "context" + "crypto/sha1" "crypto/tls" "crypto/x509" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -30,7 +32,7 @@ import ( "k8s.io/kubernetes/pkg/cloudprovider/providers/vsphere/vclib/fixtures" ) -func createTestServer(t *testing.T, caCertPath, serverCertPath, serverKeyPath string, handler http.HandlerFunc) *httptest.Server { +func createTestServer(t *testing.T, caCertPath, serverCertPath, serverKeyPath string, handler http.HandlerFunc) (*httptest.Server, string) { caCertPEM, err := ioutil.ReadFile(caCertPath) if err != nil { t.Fatalf("Could not read ca cert from file") @@ -54,22 +56,20 @@ func createTestServer(t *testing.T, caCertPath, serverCertPath, serverKeyPath st RootCAs: certPool, } - return server + // calculate the leaf certificate's fingerprint + x509LeafCert := server.TLS.Certificates[0].Certificate[0] + tpBytes := sha1.Sum(x509LeafCert) + tpString := fmt.Sprintf("%x", tpBytes) + + return server, tpString } func TestWithValidCaCert(t *testing.T) { - gotRequest := false - handler := func(w http.ResponseWriter, r *http.Request) { - gotRequest = true - } + handler, verify := getRequestVerifier(t) - server := createTestServer(t, fixtures.CaCertPath, fixtures.ServerCertPath, fixtures.ServerKeyPath, handler) + server, _ := createTestServer(t, fixtures.CaCertPath, fixtures.ServerCertPath, fixtures.ServerKeyPath, handler) server.StartTLS() - - u, err := url.Parse(server.URL) - if err != nil { - t.Fatalf("Cannot parse URL: %v", err) - } + u := mustParseUrl(t, server.URL) connection := &vclib.VSphereConnection{ Hostname: u.Hostname(), @@ -80,9 +80,26 @@ func TestWithValidCaCert(t *testing.T) { // Ignoring error here, because we only care about the TLS connection connection.NewClient(context.Background()) - if !gotRequest { - t.Fatalf("Never saw a request, TLS connection could not be established") + verify() +} + +func TestWithValidThumbprint(t *testing.T) { + handler, verify := getRequestVerifier(t) + + server, serverThumbprint := createTestServer(t, fixtures.CaCertPath, fixtures.ServerCertPath, fixtures.ServerKeyPath, handler) + server.StartTLS() + u := mustParseUrl(t, server.URL) + + connection := &vclib.VSphereConnection{ + Hostname: u.Hostname(), + Port: u.Port(), + Thumbprint: serverThumbprint, } + + // Ignoring error here, because we only care about the TLS connection + connection.NewClient(context.Background()) + + verify() } func TestWithInvalidCaCertPath(t *testing.T) { @@ -133,3 +150,27 @@ type fakeTransport struct{} func (ft fakeTransport) RoundTrip(*http.Request) (*http.Response, error) { return nil, nil } + +func getRequestVerifier(t *testing.T) (http.HandlerFunc, func()) { + gotRequest := false + + handler := func(w http.ResponseWriter, r *http.Request) { + gotRequest = true + } + + checker := func() { + if !gotRequest { + t.Fatalf("Never saw a request, maybe TLS connection could not be established?") + } + } + + return handler, checker +} + +func mustParseUrl(t *testing.T, i string) *url.URL { + u, err := url.Parse(i) + if err != nil { + t.Fatalf("Cannot parse URL: %v", err) + } + return u +}