From 1043126135231cef469f7055f561869d0edaf6e6 Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Fri, 9 Oct 2015 01:18:16 -0400 Subject: [PATCH] Refactor SSH tunneling, fix proxy transport TLS/Dial extraction --- cmd/kube-apiserver/app/server.go | 37 +++- pkg/apiserver/api_installer.go | 3 +- pkg/apiserver/apiserver.go | 2 - pkg/apiserver/proxy.go | 55 +---- pkg/apiserver/proxy_test.go | 16 ++ pkg/client/chaosclient/chaosclient.go | 8 + pkg/client/unversioned/debugging.go | 7 + pkg/client/unversioned/transport.go | 20 ++ pkg/master/master.go | 253 +++-------------------- pkg/master/master_test.go | 128 ++---------- pkg/master/tunneler.go | 262 ++++++++++++++++++++++++ pkg/master/tunneler_test.go | 139 +++++++++++++ pkg/registry/generic/rest/proxy.go | 85 +------- pkg/registry/generic/rest/proxy_test.go | 16 ++ pkg/registry/node/etcd/etcd.go | 9 +- pkg/registry/node/etcd/etcd_test.go | 2 +- pkg/registry/node/strategy.go | 8 +- pkg/registry/pod/etcd/etcd.go | 14 +- pkg/registry/pod/etcd/etcd_test.go | 4 +- pkg/registry/pod/strategy.go | 12 +- pkg/registry/service/rest.go | 14 +- pkg/registry/service/rest_test.go | 2 +- pkg/util/http.go | 40 ++++ pkg/util/proxy/dial.go | 106 ++++++++++ pkg/util/proxy/transport.go | 7 + test/e2e/kubectl.go | 3 + 26 files changed, 739 insertions(+), 513 deletions(-) create mode 100644 pkg/master/tunneler.go create mode 100644 pkg/master/tunneler_test.go create mode 100644 pkg/util/proxy/dial.go diff --git a/cmd/kube-apiserver/app/server.go b/cmd/kube-apiserver/app/server.go index 6419497c053..19245c45cd8 100644 --- a/cmd/kube-apiserver/app/server.go +++ b/cmd/kube-apiserver/app/server.go @@ -376,6 +376,30 @@ func (s *APIServer) Run(_ []string) error { glog.Fatalf("Cloud provider could not be initialized: %v", err) } + // Setup tunneler if needed + var tunneler master.Tunneler + var proxyDialerFn apiserver.ProxyDialerFunc + if len(s.SSHUser) > 0 { + // Get ssh key distribution func, if supported + var installSSH master.InstallSSHKey + if cloud != nil { + if instances, supported := cloud.Instances(); supported { + installSSH = instances.AddSSHKeyToAllInstances + } + } + + // Set up the tunneler + tunneler = master.NewSSHTunneler(s.SSHUser, s.SSHKeyfile, installSSH) + + // Use the tunneler's dialer to connect to the kubelet + s.KubeletConfig.Dial = tunneler.Dial + // Use the tunneler's dialer when proxying to pods, services, and nodes + proxyDialerFn = tunneler.Dial + } + + // Proxying to pods and services is IP-based... don't expect to be able to verify the hostname + proxyTLSClientConfig := &tls.Config{InsecureSkipVerify: true} + kubeletClient, err := client.NewKubeletClient(&s.KubeletConfig) if err != nil { glog.Fatalf("Failure to start kubelet client: %v", err) @@ -508,12 +532,7 @@ func (s *APIServer) Run(_ []string) error { } } } - var installSSH master.InstallSSHKey - if cloud != nil { - if instances, supported := cloud.Instances(); supported { - installSSH = instances.AddSSHKeyToAllInstances - } - } + config := &master.Config{ StorageDestinations: storageDestinations, StorageVersions: storageVersions, @@ -542,9 +561,9 @@ func (s *APIServer) Run(_ []string) error { ClusterName: s.ClusterName, ExternalHost: s.ExternalHost, MinRequestTimeout: s.MinRequestTimeout, - SSHUser: s.SSHUser, - SSHKeyfile: s.SSHKeyfile, - InstallSSHKey: installSSH, + ProxyDialer: proxyDialerFn, + ProxyTLSClientConfig: proxyTLSClientConfig, + Tunneler: tunneler, ServiceNodePortRange: s.ServiceNodePortRange, KubernetesServiceNodePort: s.KubernetesServiceNodePort, } diff --git a/pkg/apiserver/api_installer.go b/pkg/apiserver/api_installer.go index 54a2c188f13..bd35e8b32f5 100644 --- a/pkg/apiserver/api_installer.go +++ b/pkg/apiserver/api_installer.go @@ -41,7 +41,6 @@ type APIInstaller struct { info *APIRequestInfoResolver prefix string // Path prefix where API resources are to be registered. minRequestTimeout time.Duration - proxyDialerFn ProxyDialerFunc } // Struct capturing information about an action ("GET", "POST", "WATCH", PROXY", etc). @@ -64,7 +63,7 @@ var errEmptyName = errors.NewBadRequest("name must be provided") func (a *APIInstaller) Install(ws *restful.WebService) (apiResources []api.APIResource, errors []error) { errors = make([]error, 0) - proxyHandler := (&ProxyHandler{a.prefix + "/proxy/", a.group.Storage, a.group.Codec, a.group.Context, a.info, a.proxyDialerFn}) + proxyHandler := (&ProxyHandler{a.prefix + "/proxy/", a.group.Storage, a.group.Codec, a.group.Context, a.info}) // Register the paths in a deterministic (sorted) order to get a deterministic swagger spec. paths := make([]string, len(a.group.Storage)) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 4314e0d1320..3e6daf34d61 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -105,7 +105,6 @@ type APIGroupVersion struct { Admit admission.Interface Context api.RequestContextMapper - ProxyDialerFn ProxyDialerFunc MinRequestTimeout time.Duration } @@ -164,7 +163,6 @@ func (g *APIGroupVersion) newInstaller() *APIInstaller { info: g.APIRequestInfoResolver, prefix: prefix, minRequestTimeout: g.MinRequestTimeout, - proxyDialerFn: g.ProxyDialerFn, } return installer } diff --git a/pkg/apiserver/proxy.go b/pkg/apiserver/proxy.go index be33a05c588..3398ab71ff3 100644 --- a/pkg/apiserver/proxy.go +++ b/pkg/apiserver/proxy.go @@ -17,11 +17,8 @@ limitations under the License. package apiserver import ( - "crypto/tls" - "fmt" "io" "math/rand" - "net" "net/http" "net/http/httputil" "net/url" @@ -40,7 +37,6 @@ import ( proxyutil "k8s.io/kubernetes/pkg/util/proxy" "github.com/golang/glog" - "k8s.io/kubernetes/third_party/golang/netutil" ) // ProxyHandler provides a http.Handler which will proxy traffic to locations @@ -51,8 +47,6 @@ type ProxyHandler struct { codec runtime.Codec context api.RequestContextMapper apiRequestInfoResolver *APIRequestInfoResolver - - dial func(network, addr string) (net.Conn, error) } func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -125,11 +119,8 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { httpCode = http.StatusNotFound return } - // If we have a custom dialer, and no pre-existing transport, initialize it to use the dialer. - if roundTripper == nil && r.dial != nil { - glog.V(5).Infof("[%x: %v] making a dial-only transport...", proxyHandlerTraceID, req.URL) - roundTripper = &http.Transport{Dial: r.dial} - } else if roundTripper != nil { + + if roundTripper != nil { glog.V(5).Infof("[%x: %v] using transport %T...", proxyHandlerTraceID, req.URL, roundTripper) } @@ -217,7 +208,7 @@ func (r *ProxyHandler) tryUpgrade(w http.ResponseWriter, req, newReq *http.Reque if !httpstream.IsUpgradeRequest(req) { return false } - backendConn, err := dialURL(location, transport) + backendConn, err := proxyutil.DialURL(location, transport) if err != nil { status := errToAPIStatus(err) writeJSON(status.Code, r.codec, status, w, true) @@ -264,46 +255,6 @@ func (r *ProxyHandler) tryUpgrade(w http.ResponseWriter, req, newReq *http.Reque return true } -func dialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) { - dialAddr := netutil.CanonicalAddr(url) - - switch url.Scheme { - case "http": - return net.Dial("tcp", dialAddr) - case "https": - // Get the tls config from the transport if we recognize it - var tlsConfig *tls.Config - if transport != nil { - httpTransport, ok := transport.(*http.Transport) - if ok { - tlsConfig = httpTransport.TLSClientConfig - } - } - - // Dial - tlsConn, err := tls.Dial("tcp", dialAddr, tlsConfig) - if err != nil { - return nil, err - } - - // Return if we were configured to skip validation - if tlsConfig != nil && tlsConfig.InsecureSkipVerify { - return tlsConn, nil - } - - // Verify - host, _, _ := net.SplitHostPort(dialAddr) - if err := tlsConn.VerifyHostname(host); err != nil { - tlsConn.Close() - return nil, err - } - - return tlsConn, nil - default: - return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme) - } -} - // borrowed from net/http/httputil/reverseproxy.go func singleJoiningSlash(a, b string) string { aslash := strings.HasSuffix(a, "/") diff --git a/pkg/apiserver/proxy_test.go b/pkg/apiserver/proxy_test.go index ddbd6c063c9..3980181e2b4 100644 --- a/pkg/apiserver/proxy_test.go +++ b/pkg/apiserver/proxy_test.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" @@ -171,6 +172,21 @@ func TestProxyUpgrade(t *testing.T) { }, ProxyTransport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: localhostPool}}, }, + "https (valid hostname + RootCAs + custom dialer)": { + ServerFunc: func(h http.Handler) *httptest.Server { + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Errorf("https (valid hostname): proxy_test: %v", err) + } + ts := httptest.NewUnstartedServer(h) + ts.TLS = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + ts.StartTLS() + return ts + }, + ProxyTransport: &http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}, + }, } for k, tc := range testcases { diff --git a/pkg/client/chaosclient/chaosclient.go b/pkg/client/chaosclient/chaosclient.go index ec50ce2f625..b74ee80a344 100644 --- a/pkg/client/chaosclient/chaosclient.go +++ b/pkg/client/chaosclient/chaosclient.go @@ -28,6 +28,8 @@ import ( "net/http" "reflect" "runtime" + + "k8s.io/kubernetes/pkg/util" ) // chaosrt provides the ability to perform simulations of HTTP client failures @@ -86,6 +88,12 @@ func (rt *chaosrt) RoundTrip(req *http.Request) (*http.Response, error) { return rt.rt.RoundTrip(req) } +var _ = util.RoundTripperWrapper(&chaosrt{}) + +func (rt *chaosrt) WrappedRoundTripper() http.RoundTripper { + return rt.rt +} + // Seed represents a consistent stream of chaos. type Seed struct { *rand.Rand diff --git a/pkg/client/unversioned/debugging.go b/pkg/client/unversioned/debugging.go index df43e8984d0..76f8c2caa58 100644 --- a/pkg/client/unversioned/debugging.go +++ b/pkg/client/unversioned/debugging.go @@ -23,6 +23,7 @@ import ( "github.com/golang/glog" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/sets" ) @@ -133,3 +134,9 @@ func (rt *DebuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e return response, err } + +var _ = util.RoundTripperWrapper(&DebuggingRoundTripper{}) + +func (rt *DebuggingRoundTripper) WrappedRoundTripper() http.RoundTripper { + return rt.delegatedRoundTripper +} diff --git a/pkg/client/unversioned/transport.go b/pkg/client/unversioned/transport.go index ecb73dc1e7a..f31b1a47e12 100644 --- a/pkg/client/unversioned/transport.go +++ b/pkg/client/unversioned/transport.go @@ -22,6 +22,8 @@ import ( "fmt" "io/ioutil" "net/http" + + "k8s.io/kubernetes/pkg/util" ) type userAgentRoundTripper struct { @@ -42,6 +44,12 @@ func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, e return rt.rt.RoundTrip(req) } +var _ = util.RoundTripperWrapper(&userAgentRoundTripper{}) + +func (rt *userAgentRoundTripper) WrappedRoundTripper() http.RoundTripper { + return rt.rt +} + type basicAuthRoundTripper struct { username string password string @@ -63,6 +71,12 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e return rt.rt.RoundTrip(req) } +var _ = util.RoundTripperWrapper(&basicAuthRoundTripper{}) + +func (rt *basicAuthRoundTripper) WrappedRoundTripper() http.RoundTripper { + return rt.rt +} + type bearerAuthRoundTripper struct { bearer string rt http.RoundTripper @@ -84,6 +98,12 @@ func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, return rt.rt.RoundTrip(req) } +var _ = util.RoundTripperWrapper(&bearerAuthRoundTripper{}) + +func (rt *bearerAuthRoundTripper) WrappedRoundTripper() http.RoundTripper { + return rt.rt +} + // TLSConfigFor returns a tls.Config that will provide the transport level security defined // by the provided Config. Will return nil if no transport level security is requested. func TLSConfigFor(config *Config) (*tls.Config, error) { diff --git a/pkg/master/master.go b/pkg/master/master.go index 439f04609b9..b4ef96b2f5b 100644 --- a/pkg/master/master.go +++ b/pkg/master/master.go @@ -17,18 +17,15 @@ limitations under the License. package master import ( + "crypto/tls" "fmt" - "io/ioutil" - "math/rand" "net" "net/http" "net/http/pprof" "net/url" - "os" "strconv" "strings" "sync" - "sync/atomic" "time" "k8s.io/kubernetes/pkg/admission" @@ -240,10 +237,12 @@ type Config struct { // The range of ports to be assigned to services with type=NodePort or greater ServiceNodePortRange util.PortRange - // Used for secure proxy. If empty, don't use secure proxy. - SSHUser string - SSHKeyfile string - InstallSSHKey InstallSSHKey + // Used to customize default proxy dial/tls options + ProxyDialer apiserver.ProxyDialerFunc + ProxyTLSClientConfig *tls.Config + + // Used to start and monitor tunneling + Tunneler Tunneler KubernetesServiceNodePort int } @@ -305,14 +304,11 @@ type Master struct { Handler http.Handler InsecureHandler http.Handler - // Used for secure proxy - dialer apiserver.ProxyDialerFunc - tunnels *util.SSHTunnelList - tunnelsLock sync.Mutex - installSSHKey InstallSSHKey - lastSync int64 // Seconds since Epoch - lastSyncMetric prometheus.GaugeFunc - clock util.Clock + // Used for custom proxy dialing, and proxy TLS options + proxyTransport http.RoundTripper + + // Used to start and monitor tunneling + tunneler Tunneler // storage for third party objects thirdPartyStorage storage.Interface @@ -453,7 +449,8 @@ func New(c *Config) *Master { // TODO: serviceReadWritePort should be passed in as an argument, it may not always be 443 serviceReadWritePort: 443, - installSSHKey: c.InstallSSHKey, + tunneler: c.Tunneler, + KubernetesServiceNodePort: c.KubernetesServiceNodePort, } @@ -505,10 +502,18 @@ func NewHandlerContainer(mux *http.ServeMux) *restful.Container { // init initializes master. func (m *Master) init(c *Config) { + + if c.ProxyDialer != nil || c.ProxyTLSClientConfig != nil { + m.proxyTransport = util.SetTransportDefaults(&http.Transport{ + Dial: c.ProxyDialer, + TLSClientConfig: c.ProxyTLSClientConfig, + }) + } + healthzChecks := []healthz.HealthzChecker{} - m.clock = util.RealClock{} + dbClient := func(resource string) storage.Interface { return c.StorageDestinations.get("", resource) } - podStorage := podetcd.NewStorage(dbClient("pods"), c.EnableWatchCache, c.KubeletClient) + podStorage := podetcd.NewStorage(dbClient("pods"), c.EnableWatchCache, c.KubeletClient, m.proxyTransport) podTemplateStorage := podtemplateetcd.NewREST(dbClient("podTemplates")) @@ -527,7 +532,7 @@ func (m *Master) init(c *Config) { endpointsStorage := endpointsetcd.NewREST(dbClient("endpoints"), c.EnableWatchCache) m.endpointRegistry = endpoint.NewRegistry(endpointsStorage) - nodeStorage, nodeStatusStorage := nodeetcd.NewREST(dbClient("nodes"), c.EnableWatchCache, c.KubeletClient) + nodeStorage, nodeStatusStorage := nodeetcd.NewREST(dbClient("nodes"), c.EnableWatchCache, c.KubeletClient, m.proxyTransport) m.nodeRegistry = node.NewRegistry(nodeStorage) serviceStorage := serviceetcd.NewREST(dbClient("services")) @@ -569,7 +574,7 @@ func (m *Master) init(c *Config) { "replicationControllers": controllerStorage, "replicationControllers/status": controllerStatusStorage, - "services": service.NewStorage(m.serviceRegistry, m.endpointRegistry, serviceClusterIPAllocator, serviceNodePortAllocator), + "services": service.NewStorage(m.serviceRegistry, m.endpointRegistry, serviceClusterIPAllocator, serviceNodePortAllocator, m.proxyTransport), "endpoints": endpointsStorage, "nodes": nodeStorage, "nodes/status": nodeStatusStorage, @@ -591,51 +596,13 @@ func (m *Master) init(c *Config) { "componentStatuses": componentstatus.NewStorage(func() map[string]apiserver.Server { return m.getServersToValidate(c) }), } - // establish the node proxy dialer - if len(c.SSHUser) > 0 { - // Usernames are capped @ 32 - if len(c.SSHUser) > 32 { - glog.Warning("SSH User is too long, truncating to 32 chars") - c.SSHUser = c.SSHUser[0:32] - } - glog.Infof("Setting up proxy: %s %s", c.SSHUser, c.SSHKeyfile) - - // public keyfile is written last, so check for that. - publicKeyFile := c.SSHKeyfile + ".pub" - exists, err := util.FileExists(publicKeyFile) - if err != nil { - glog.Errorf("Error detecting if key exists: %v", err) - } else if !exists { - glog.Infof("Key doesn't exist, attempting to create") - err := m.generateSSHKey(c.SSHUser, c.SSHKeyfile, publicKeyFile) - if err != nil { - glog.Errorf("Failed to create key pair: %v", err) - } - } - m.tunnels = &util.SSHTunnelList{} - m.dialer = m.Dial - m.setupSecureProxy(c.SSHUser, c.SSHKeyfile, publicKeyFile) - m.lastSync = m.clock.Now().Unix() - - // This is pretty ugly. A better solution would be to pull this all the way up into the - // server.go file. - httpKubeletClient, ok := c.KubeletClient.(*client.HTTPKubeletClient) - if ok { - httpKubeletClient.Config.Dial = m.dialer - transport, err := client.MakeTransport(httpKubeletClient.Config) - if err != nil { - glog.Errorf("Error setting up transport over SSH: %v", err) - } else { - httpKubeletClient.Client.Transport = transport - } - } else { - glog.Errorf("Failed to cast %v to HTTPKubeletClient, skipping SSH tunnel.", c.KubeletClient) - } + if m.tunneler != nil { + m.tunneler.Run(m.getNodeAddresses) healthzChecks = append(healthzChecks, healthz.NamedCheck("SSH Tunnel Check", m.IsTunnelSyncHealthy)) - m.lastSyncMetric = prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + prometheus.NewGaugeFunc(prometheus.GaugeOpts{ Name: "apiserver_proxy_tunnel_sync_latency_secs", Help: "The time since the last successful synchronization of the SSH tunnels for proxy requests.", - }, func() float64 { return float64(m.secondsSinceSync()) }) + }, func() float64 { return float64(m.tunneler.SecondsSinceSync()) }) } apiVersions := []string{} @@ -875,7 +842,6 @@ func (m *Master) defaultAPIGroupVersion() *apiserver.APIGroupVersion { Admit: m.admissionControl, Context: m.requestContextMapper, - ProxyDialerFn: m.dialer, MinRequestTimeout: m.minRequestTimeout, } } @@ -1031,7 +997,6 @@ func (m *Master) thirdpartyapi(group, kind, version string) *apiserver.APIGroupV Context: m.requestContextMapper, - ProxyDialerFn: m.dialer, MinRequestTimeout: m.minRequestTimeout, } } @@ -1094,7 +1059,6 @@ func (m *Master) experimental(c *Config) *apiserver.APIGroupVersion { Admit: m.admissionControl, Context: m.requestContextMapper, - ProxyDialerFn: m.dialer, MinRequestTimeout: m.minRequestTimeout, } } @@ -1117,41 +1081,6 @@ func findExternalAddress(node *api.Node) (string, error) { return "", fmt.Errorf("Couldn't find external address: %v", node) } -func (m *Master) Dial(net, addr string) (net.Conn, error) { - // Only lock while picking a tunnel. - tunnel, err := func() (util.SSHTunnelEntry, error) { - m.tunnelsLock.Lock() - defer m.tunnelsLock.Unlock() - return m.tunnels.PickRandomTunnel() - }() - if err != nil { - return nil, err - } - - start := time.Now() - id := rand.Int63() // So you can match begins/ends in the log. - glog.V(3).Infof("[%x: %v] Dialing...", id, tunnel.Address) - defer func() { - glog.V(3).Infof("[%x: %v] Dialed in %v.", id, tunnel.Address, time.Now().Sub(start)) - }() - return tunnel.Tunnel.Dial(net, addr) -} - -func (m *Master) needToReplaceTunnels(addrs []string) bool { - m.tunnelsLock.Lock() - defer m.tunnelsLock.Unlock() - if m.tunnels == nil || m.tunnels.Len() != len(addrs) { - return true - } - // TODO (cjcullen): This doesn't need to be n^2 - for ix := range addrs { - if !m.tunnels.Has(addrs[ix]) { - return true - } - } - return false -} - func (m *Master) getNodeAddresses() ([]string, error) { nodes, err := m.nodeRegistry.ListNodes(api.NewDefaultContext(), labels.Everything(), fields.Everything()) if err != nil { @@ -1170,126 +1099,12 @@ func (m *Master) getNodeAddresses() ([]string, error) { } func (m *Master) IsTunnelSyncHealthy(req *http.Request) error { - lag := m.secondsSinceSync() + if m.tunneler == nil { + return nil + } + lag := m.tunneler.SecondsSinceSync() if lag > 600 { return fmt.Errorf("Tunnel sync is taking to long: %d", lag) } return nil } - -func (m *Master) secondsSinceSync() int64 { - now := m.clock.Now().Unix() - then := atomic.LoadInt64(&m.lastSync) - return now - then -} - -func (m *Master) replaceTunnels(user, keyfile string, newAddrs []string) error { - glog.Infof("replacing tunnels. New addrs: %v", newAddrs) - tunnels := util.MakeSSHTunnels(user, keyfile, newAddrs) - if err := tunnels.Open(); err != nil { - return err - } - m.tunnelsLock.Lock() - defer m.tunnelsLock.Unlock() - if m.tunnels != nil { - m.tunnels.Close() - } - m.tunnels = tunnels - atomic.StoreInt64(&m.lastSync, m.clock.Now().Unix()) - return nil -} - -func (m *Master) loadTunnels(user, keyfile string) error { - addrs, err := m.getNodeAddresses() - if err != nil { - return err - } - if !m.needToReplaceTunnels(addrs) { - return nil - } - // TODO: This is going to unnecessarily close connections to unchanged nodes. - // See comment about using Watch above. - glog.Info("found different nodes. Need to replace tunnels") - return m.replaceTunnels(user, keyfile, addrs) -} - -func (m *Master) refreshTunnels(user, keyfile string) error { - addrs, err := m.getNodeAddresses() - if err != nil { - return err - } - return m.replaceTunnels(user, keyfile, addrs) -} - -func (m *Master) setupSecureProxy(user, privateKeyfile, publicKeyfile string) { - // Sync loop to ensure that the SSH key has been installed. - go util.Until(func() { - if m.installSSHKey == nil { - glog.Error("Won't attempt to install ssh key: installSSHKey function is nil") - return - } - key, err := util.ParsePublicKeyFromFile(publicKeyfile) - if err != nil { - glog.Errorf("Failed to load public key: %v", err) - return - } - keyData, err := util.EncodeSSHKey(key) - if err != nil { - glog.Errorf("Failed to encode public key: %v", err) - return - } - if err := m.installSSHKey(user, keyData); err != nil { - glog.Errorf("Failed to install ssh key: %v", err) - } - }, 5*time.Minute, util.NeverStop) - // Sync loop for tunnels - // TODO: switch this to watch. - go util.Until(func() { - if err := m.loadTunnels(user, privateKeyfile); err != nil { - glog.Errorf("Failed to load SSH Tunnels: %v", err) - } - if m.tunnels != nil && m.tunnels.Len() != 0 { - // Sleep for 10 seconds if we have some tunnels. - // TODO (cjcullen): tunnels can lag behind actually existing nodes. - time.Sleep(9 * time.Second) - } - }, 1*time.Second, util.NeverStop) - // Refresh loop for tunnels - // TODO: could make this more controller-ish - go util.Until(func() { - time.Sleep(5 * time.Minute) - if err := m.refreshTunnels(user, privateKeyfile); err != nil { - glog.Errorf("Failed to refresh SSH Tunnels: %v", err) - } - }, 0*time.Second, util.NeverStop) -} - -func (m *Master) generateSSHKey(user, privateKeyfile, publicKeyfile string) error { - // TODO: user is not used. Consider removing it as an input to the function. - private, public, err := util.GenerateKey(2048) - if err != nil { - return err - } - // If private keyfile already exists, we must have only made it halfway - // through last time, so delete it. - exists, err := util.FileExists(privateKeyfile) - if err != nil { - glog.Errorf("Error detecting if private key exists: %v", err) - } else if exists { - glog.Infof("Private key exists, but public key does not") - if err := os.Remove(privateKeyfile); err != nil { - glog.Errorf("Failed to remove stale private key: %v", err) - } - } - if err := ioutil.WriteFile(privateKeyfile, util.EncodePrivateKey(private), 0600); err != nil { - return err - } - publicKeyBytes, err := util.EncodePublicKey(public) - if err != nil { - return err - } - if err := ioutil.WriteFile(publicKeyfile+".tmp", publicKeyBytes, 0600); err != nil { - return err - } - return os.Rename(publicKeyfile+".tmp", publicKeyfile) -} diff --git a/pkg/master/master_test.go b/pkg/master/master_test.go index f00f76b9de0..bdb9a79928b 100644 --- a/pkg/master/master_test.go +++ b/pkg/master/master_test.go @@ -18,6 +18,7 @@ package master import ( "bytes" + "crypto/tls" "encoding/json" "errors" "fmt" @@ -25,12 +26,9 @@ import ( "net" "net/http" "net/http/httptest" - "os" - "path/filepath" "reflect" "strings" "testing" - "time" "k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/api/latest" @@ -81,7 +79,12 @@ func setUp(t *testing.T) (Master, Config, *assert.Assertions) { // using the configuration properly. func TestNew(t *testing.T) { _, config, assert := setUp(t) + config.KubeletClient = client.FakeKubeletClient{} + + config.ProxyDialer = func(network, addr string) (net.Conn, error) { return nil, nil } + config.ProxyTLSClientConfig = &tls.Config{} + master := New(&config) // Verify many of the variables match their config counterparts @@ -106,7 +109,15 @@ func TestNew(t *testing.T) { assert.Equal(master.clusterIP, config.PublicAddress) assert.Equal(master.publicReadWritePort, config.ReadWritePort) assert.Equal(master.serviceReadWriteIP, config.ServiceReadWriteIP) - assert.Equal(master.installSSHKey, config.InstallSSHKey) + assert.Equal(master.tunneler, config.Tunneler) + + // These functions should point to the same memory location + masterDialer, _ := util.Dialer(master.proxyTransport) + masterDialerFunc := fmt.Sprintf("%p", masterDialer) + configDialerFunc := fmt.Sprintf("%p", config.ProxyDialer) + assert.Equal(masterDialerFunc, configDialerFunc) + + assert.Equal(master.proxyTransport.(*http.Transport).TLSClientConfig, config.ProxyTLSClientConfig) } // TestNewEtcdStorage verifies that the usage of NewEtcdStorage reacts properly when @@ -271,7 +282,6 @@ func TestInstallSwaggerAPI(t *testing.T) { // creates the expected APIGroupVersion based off of master. func TestDefaultAPIGroupVersion(t *testing.T) { master, _, assert := setUp(t) - master.dialer = func(network, addr string) (net.Conn, error) { return nil, nil } apiGroup := master.defaultAPIGroupVersion() @@ -279,11 +289,6 @@ func TestDefaultAPIGroupVersion(t *testing.T) { assert.Equal(apiGroup.Admit, master.admissionControl) assert.Equal(apiGroup.Context, master.requestContextMapper) assert.Equal(apiGroup.MinRequestTimeout, master.minRequestTimeout) - - // These functions should be different instances of the same function - groupDialerFunc := fmt.Sprintf("%+v", apiGroup.ProxyDialerFn) - masterDialerFunc := fmt.Sprintf("%+v", master.dialer) - assert.Equal(groupDialerFunc, masterDialerFunc) } // TestExpapi verifies that the unexported exapi creates @@ -299,42 +304,6 @@ func TestExpapi(t *testing.T) { assert.Equal(expAPIGroup.Version, latest.GroupOrDie("extensions").GroupVersion) } -// TestSecondsSinceSync verifies that proper results are returned -// when checking the time between syncs -func TestSecondsSinceSync(t *testing.T) { - master, _, assert := setUp(t) - master.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix() - - // Nano Second. No difference. - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 2, time.UTC)} - assert.Equal(int64(0), master.secondsSinceSync()) - - // Second - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 2, 1, time.UTC)} - assert.Equal(int64(1), master.secondsSinceSync()) - - // Minute - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 2, 1, 1, time.UTC)} - assert.Equal(int64(60), master.secondsSinceSync()) - - // Hour - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 2, 1, 1, 1, time.UTC)} - assert.Equal(int64(3600), master.secondsSinceSync()) - - // Day - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 2, 1, 1, 1, 1, time.UTC)} - assert.Equal(int64(86400), master.secondsSinceSync()) - - // Month - master.clock = &util.FakeClock{Time: time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC)} - assert.Equal(int64(2678400), master.secondsSinceSync()) - - // Future Month. Should be -Month. - master.lastSync = time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC).Unix() - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC)} - assert.Equal(int64(-2678400), master.secondsSinceSync()) -} - // TestGetNodeAddresses verifies that proper results are returned // when requesting node addresses. func TestGetNodeAddresses(t *testing.T) { @@ -366,73 +335,6 @@ func TestGetNodeAddresses(t *testing.T) { assert.Equal([]string{"127.0.0.2", "127.0.0.2"}, addrs) } -// TestRefreshTunnels verifies that the function errors when no addresses -// are associated with nodes -func TestRefreshTunnels(t *testing.T) { - master, _, assert := setUp(t) - - // Fail case (no addresses associated with nodes) - assert.Error(master.refreshTunnels("test", "/tmp/undefined")) - - // TODO: pass case without needing actual connections? -} - -// TestIsTunnelSyncHealthy verifies that the 600 second lag test -// is honored. -func TestIsTunnelSyncHealthy(t *testing.T) { - master, _, assert := setUp(t) - - // Pass case: 540 second lag - master.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix() - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 9, 1, 1, time.UTC)} - err := master.IsTunnelSyncHealthy(nil) - assert.NoError(err, "IsTunnelSyncHealthy() should not have returned an error.") - - // Fail case: 720 second lag - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 12, 1, 1, time.UTC)} - err = master.IsTunnelSyncHealthy(nil) - assert.Error(err, "IsTunnelSyncHealthy() should have returned an error.") -} - -// generateTempFile creates a temporary file path -func generateTempFilePath(prefix string) string { - tmpPath, _ := filepath.Abs(fmt.Sprintf("%s/%s-%d", os.TempDir(), prefix, time.Now().Unix())) - return tmpPath -} - -// TestGenerateSSHKey verifies that SSH key generation does indeed -// generate keys even with keys already exist. -func TestGenerateSSHKey(t *testing.T) { - master, _, assert := setUp(t) - - privateKey := generateTempFilePath("private") - publicKey := generateTempFilePath("public") - - // Make sure we have no test keys laying around - os.Remove(privateKey) - os.Remove(publicKey) - - // Pass case: Sunny day case - err := master.generateSSHKey("unused", privateKey, publicKey) - assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) - - // Pass case: PrivateKey exists test case - os.Remove(publicKey) - err = master.generateSSHKey("unused", privateKey, publicKey) - assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) - - // Pass case: PublicKey exists test case - os.Remove(privateKey) - err = master.generateSSHKey("unused", privateKey, publicKey) - assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) - - // Make sure we have no test keys laying around - os.Remove(privateKey) - os.Remove(publicKey) - - // TODO: testing error cases where the file can not be removed? -} - func TestDiscoveryAtAPIS(t *testing.T) { master, config, assert := setUp(t) master.exp = true diff --git a/pkg/master/tunneler.go b/pkg/master/tunneler.go new file mode 100644 index 00000000000..f07ddde0785 --- /dev/null +++ b/pkg/master/tunneler.go @@ -0,0 +1,262 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package master + +import ( + "io/ioutil" + "math/rand" + "net" + "os" + "sync" + "sync/atomic" + "time" + + "k8s.io/kubernetes/pkg/util" + + "github.com/golang/glog" + "github.com/prometheus/client_golang/prometheus" +) + +type AddressFunc func() (addresses []string, err error) + +type Tunneler interface { + Run(AddressFunc) + Stop() + Dial(net, addr string) (net.Conn, error) + SecondsSinceSync() int64 +} + +type SSHTunneler struct { + SSHUser string + SSHKeyfile string + InstallSSHKey InstallSSHKey + + tunnels *util.SSHTunnelList + tunnelsLock sync.Mutex + lastSync int64 // Seconds since Epoch + lastSyncMetric prometheus.GaugeFunc + clock util.Clock + + getAddresses AddressFunc + stopChan chan struct{} +} + +func NewSSHTunneler(sshUser string, sshKeyfile string, installSSHKey InstallSSHKey) Tunneler { + return &SSHTunneler{ + SSHUser: sshUser, + SSHKeyfile: sshKeyfile, + InstallSSHKey: installSSHKey, + + clock: util.RealClock{}, + } +} + +// Run establishes tunnel loops and returns +func (c *SSHTunneler) Run(getAddresses AddressFunc) { + if c.stopChan != nil { + return + } + c.stopChan = make(chan struct{}) + + // Save the address getter + if getAddresses != nil { + c.getAddresses = getAddresses + } + + // Usernames are capped @ 32 + if len(c.SSHUser) > 32 { + glog.Warning("SSH User is too long, truncating to 32 chars") + c.SSHUser = c.SSHUser[0:32] + } + glog.Infof("Setting up proxy: %s %s", c.SSHUser, c.SSHKeyfile) + + // public keyfile is written last, so check for that. + publicKeyFile := c.SSHKeyfile + ".pub" + exists, err := util.FileExists(publicKeyFile) + if err != nil { + glog.Errorf("Error detecting if key exists: %v", err) + } else if !exists { + glog.Infof("Key doesn't exist, attempting to create") + err := c.generateSSHKey(c.SSHUser, c.SSHKeyfile, publicKeyFile) + if err != nil { + glog.Errorf("Failed to create key pair: %v", err) + } + } + c.tunnels = &util.SSHTunnelList{} + c.setupSecureProxy(c.SSHUser, c.SSHKeyfile, publicKeyFile) + c.lastSync = c.clock.Now().Unix() +} + +// Stop gracefully shuts down the tunneler +func (c *SSHTunneler) Stop() { + if c.stopChan != nil { + close(c.stopChan) + c.stopChan = nil + } +} + +func (c *SSHTunneler) Dial(net, addr string) (net.Conn, error) { + // Only lock while picking a tunnel. + tunnel, err := func() (util.SSHTunnelEntry, error) { + c.tunnelsLock.Lock() + defer c.tunnelsLock.Unlock() + return c.tunnels.PickRandomTunnel() + }() + if err != nil { + return nil, err + } + + start := time.Now() + id := rand.Int63() // So you can match begins/ends in the log. + glog.V(3).Infof("[%x: %v] Dialing...", id, tunnel.Address) + defer func() { + glog.V(3).Infof("[%x: %v] Dialed in %v.", id, tunnel.Address, time.Now().Sub(start)) + }() + return tunnel.Tunnel.Dial(net, addr) +} + +func (c *SSHTunneler) SecondsSinceSync() int64 { + now := c.clock.Now().Unix() + then := atomic.LoadInt64(&c.lastSync) + return now - then +} + +func (c *SSHTunneler) needToReplaceTunnels(addrs []string) bool { + c.tunnelsLock.Lock() + defer c.tunnelsLock.Unlock() + if c.tunnels == nil || c.tunnels.Len() != len(addrs) { + return true + } + // TODO (cjcullen): This doesn't need to be n^2 + for ix := range addrs { + if !c.tunnels.Has(addrs[ix]) { + return true + } + } + return false +} + +func (c *SSHTunneler) replaceTunnels(user, keyfile string, newAddrs []string) error { + glog.Infof("replacing tunnels. New addrs: %v", newAddrs) + tunnels := util.MakeSSHTunnels(user, keyfile, newAddrs) + if err := tunnels.Open(); err != nil { + return err + } + c.tunnelsLock.Lock() + defer c.tunnelsLock.Unlock() + if c.tunnels != nil { + c.tunnels.Close() + } + c.tunnels = tunnels + atomic.StoreInt64(&c.lastSync, c.clock.Now().Unix()) + return nil +} + +func (c *SSHTunneler) loadTunnels(user, keyfile string) error { + addrs, err := c.getAddresses() + if err != nil { + return err + } + if !c.needToReplaceTunnels(addrs) { + return nil + } + // TODO: This is going to unnecessarily close connections to unchanged nodes. + // See comment about using Watch above. + glog.Info("found different nodes. Need to replace tunnels") + return c.replaceTunnels(user, keyfile, addrs) +} + +func (c *SSHTunneler) refreshTunnels(user, keyfile string) error { + addrs, err := c.getAddresses() + if err != nil { + return err + } + return c.replaceTunnels(user, keyfile, addrs) +} + +func (c *SSHTunneler) setupSecureProxy(user, privateKeyfile, publicKeyfile string) { + // Sync loop to ensure that the SSH key has been installed. + go util.Until(func() { + if c.InstallSSHKey == nil { + glog.Error("Won't attempt to install ssh key: InstallSSHKey function is nil") + return + } + key, err := util.ParsePublicKeyFromFile(publicKeyfile) + if err != nil { + glog.Errorf("Failed to load public key: %v", err) + return + } + keyData, err := util.EncodeSSHKey(key) + if err != nil { + glog.Errorf("Failed to encode public key: %v", err) + return + } + if err := c.InstallSSHKey(user, keyData); err != nil { + glog.Errorf("Failed to install ssh key: %v", err) + } + }, 5*time.Minute, c.stopChan) + // Sync loop for tunnels + // TODO: switch this to watch. + go util.Until(func() { + if err := c.loadTunnels(user, privateKeyfile); err != nil { + glog.Errorf("Failed to load SSH Tunnels: %v", err) + } + if c.tunnels != nil && c.tunnels.Len() != 0 { + // Sleep for 10 seconds if we have some tunnels. + // TODO (cjcullen): tunnels can lag behind actually existing nodes. + time.Sleep(9 * time.Second) + } + }, 1*time.Second, c.stopChan) + // Refresh loop for tunnels + // TODO: could make this more controller-ish + go util.Until(func() { + time.Sleep(5 * time.Minute) + if err := c.refreshTunnels(user, privateKeyfile); err != nil { + glog.Errorf("Failed to refresh SSH Tunnels: %v", err) + } + }, 0*time.Second, c.stopChan) +} + +func (c *SSHTunneler) generateSSHKey(user, privateKeyfile, publicKeyfile string) error { + // TODO: user is not used. Consider removing it as an input to the function. + private, public, err := util.GenerateKey(2048) + if err != nil { + return err + } + // If private keyfile already exists, we must have only made it halfway + // through last time, so delete it. + exists, err := util.FileExists(privateKeyfile) + if err != nil { + glog.Errorf("Error detecting if private key exists: %v", err) + } else if exists { + glog.Infof("Private key exists, but public key does not") + if err := os.Remove(privateKeyfile); err != nil { + glog.Errorf("Failed to remove stale private key: %v", err) + } + } + if err := ioutil.WriteFile(privateKeyfile, util.EncodePrivateKey(private), 0600); err != nil { + return err + } + publicKeyBytes, err := util.EncodePublicKey(public) + if err != nil { + return err + } + if err := ioutil.WriteFile(publicKeyfile+".tmp", publicKeyBytes, 0600); err != nil { + return err + } + return os.Rename(publicKeyfile+".tmp", publicKeyfile) +} diff --git a/pkg/master/tunneler_test.go b/pkg/master/tunneler_test.go new file mode 100644 index 00000000000..7b348a525b3 --- /dev/null +++ b/pkg/master/tunneler_test.go @@ -0,0 +1,139 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package master + +import ( + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "k8s.io/kubernetes/pkg/util" + + "github.com/stretchr/testify/assert" +) + +// TestSecondsSinceSync verifies that proper results are returned +// when checking the time between syncs +func TestSecondsSinceSync(t *testing.T) { + tunneler := &SSHTunneler{} + assert := assert.New(t) + + tunneler.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix() + + // Nano Second. No difference. + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 2, time.UTC)} + assert.Equal(int64(0), tunneler.SecondsSinceSync()) + + // Second + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 2, 1, time.UTC)} + assert.Equal(int64(1), tunneler.SecondsSinceSync()) + + // Minute + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 2, 1, 1, time.UTC)} + assert.Equal(int64(60), tunneler.SecondsSinceSync()) + + // Hour + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 2, 1, 1, 1, time.UTC)} + assert.Equal(int64(3600), tunneler.SecondsSinceSync()) + + // Day + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 2, 1, 1, 1, 1, time.UTC)} + assert.Equal(int64(86400), tunneler.SecondsSinceSync()) + + // Month + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC)} + assert.Equal(int64(2678400), tunneler.SecondsSinceSync()) + + // Future Month. Should be -Month. + tunneler.lastSync = time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC).Unix() + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC)} + assert.Equal(int64(-2678400), tunneler.SecondsSinceSync()) +} + +// TestRefreshTunnels verifies that the function errors when no addresses +// are associated with nodes +func TestRefreshTunnels(t *testing.T) { + tunneler := &SSHTunneler{} + tunneler.getAddresses = func() ([]string, error) { return []string{}, nil } + assert := assert.New(t) + + // Fail case (no addresses associated with nodes) + assert.Error(tunneler.refreshTunnels("test", "/tmp/undefined")) + + // TODO: pass case without needing actual connections? +} + +// TestIsTunnelSyncHealthy verifies that the 600 second lag test +// is honored. +func TestIsTunnelSyncHealthy(t *testing.T) { + tunneler := &SSHTunneler{} + master, _, assert := setUp(t) + master.tunneler = tunneler + + // Pass case: 540 second lag + tunneler.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix() + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 9, 1, 1, time.UTC)} + err := master.IsTunnelSyncHealthy(nil) + assert.NoError(err, "IsTunnelSyncHealthy() should not have returned an error.") + + // Fail case: 720 second lag + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 12, 1, 1, time.UTC)} + err = master.IsTunnelSyncHealthy(nil) + assert.Error(err, "IsTunnelSyncHealthy() should have returned an error.") +} + +// generateTempFile creates a temporary file path +func generateTempFilePath(prefix string) string { + tmpPath, _ := filepath.Abs(fmt.Sprintf("%s/%s-%d", os.TempDir(), prefix, time.Now().Unix())) + return tmpPath +} + +// TestGenerateSSHKey verifies that SSH key generation does indeed +// generate keys even with keys already exist. +func TestGenerateSSHKey(t *testing.T) { + tunneler := &SSHTunneler{} + assert := assert.New(t) + + privateKey := generateTempFilePath("private") + publicKey := generateTempFilePath("public") + + // Make sure we have no test keys laying around + os.Remove(privateKey) + os.Remove(publicKey) + + // Pass case: Sunny day case + err := tunneler.generateSSHKey("unused", privateKey, publicKey) + assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) + + // Pass case: PrivateKey exists test case + os.Remove(publicKey) + err = tunneler.generateSSHKey("unused", privateKey, publicKey) + assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) + + // Pass case: PublicKey exists test case + os.Remove(privateKey) + err = tunneler.generateSSHKey("unused", privateKey, publicKey) + assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) + + // Make sure we have no test keys laying around + os.Remove(privateKey) + os.Remove(publicKey) + + // TODO: testing error cases where the file can not be removed? +} diff --git a/pkg/registry/generic/rest/proxy.go b/pkg/registry/generic/rest/proxy.go index 14bd0e0e27d..cc4d74440f7 100644 --- a/pkg/registry/generic/rest/proxy.go +++ b/pkg/registry/generic/rest/proxy.go @@ -17,10 +17,7 @@ limitations under the License. package rest import ( - "crypto/tls" - "fmt" "io" - "net" "net/http" "net/http/httputil" "net/url" @@ -29,12 +26,12 @@ import ( "time" "k8s.io/kubernetes/pkg/api/errors" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" "k8s.io/kubernetes/pkg/util/proxy" "github.com/golang/glog" "github.com/mxk/go-flowrate/flowrate" - "k8s.io/kubernetes/third_party/golang/netutil" ) // UpgradeAwareProxyHandler is a handler for proxy requests that may require an upgrade @@ -128,7 +125,7 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R return false } - backendConn, err := h.dialURL() + backendConn, err := proxy.DialURL(h.Location, h.Transport) if err != nil { h.err = err return true @@ -189,79 +186,6 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R return true } -func (h *UpgradeAwareProxyHandler) dialURL() (net.Conn, error) { - dialAddr := netutil.CanonicalAddr(h.Location) - - var dialer func(network, addr string) (net.Conn, error) - if httpTransport, ok := h.Transport.(*http.Transport); ok && httpTransport.Dial != nil { - dialer = httpTransport.Dial - } - - switch h.Location.Scheme { - case "http": - if dialer != nil { - return dialer("tcp", dialAddr) - } - return net.Dial("tcp", dialAddr) - case "https": - // TODO: this TLS logic can probably be cleaned up; it's messy in an attempt - // to preserve behavior that we don't know for sure is exercised. - - // Get the tls config from the transport if we recognize it - var tlsConfig *tls.Config - var tlsConn *tls.Conn - var err error - if h.Transport != nil { - httpTransport, ok := h.Transport.(*http.Transport) - if ok { - tlsConfig = httpTransport.TLSClientConfig - } - } - if dialer != nil { - // We have a dialer; use it to open the connection, then - // create a tls client using the connection. - netConn, err := dialer("tcp", dialAddr) - if err != nil { - return nil, err - } - // tls.Client requires non-nil config - if tlsConfig == nil { - glog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify") - tlsConfig = &tls.Config{ - InsecureSkipVerify: true, - } - } - tlsConn = tls.Client(netConn, tlsConfig) - if err := tlsConn.Handshake(); err != nil { - return nil, err - } - - } else { - // Dial - tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig) - if err != nil { - return nil, err - } - } - - // Return if we were configured to skip validation - if tlsConfig != nil && tlsConfig.InsecureSkipVerify { - return tlsConn, nil - } - - // Verify - host, _, _ := net.SplitHostPort(dialAddr) - if err := tlsConn.VerifyHostname(host); err != nil { - tlsConn.Close() - return nil, err - } - - return tlsConn, nil - default: - return nil, fmt.Errorf("unknown scheme: %s", h.Location.Scheme) - } -} - func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper { scheme := url.Scheme host := url.Host @@ -294,7 +218,12 @@ func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, er } removeCORSHeaders(resp) return resp, nil +} +var _ = util.RoundTripperWrapper(&corsRemovingTransport{}) + +func (rt *corsRemovingTransport) WrappedRoundTripper() http.RoundTripper { + return rt.RoundTripper } // removeCORSHeaders strip CORS headers sent from the backend diff --git a/pkg/registry/generic/rest/proxy_test.go b/pkg/registry/generic/rest/proxy_test.go index 1f20f3dac87..829b1a6d5ef 100644 --- a/pkg/registry/generic/rest/proxy_test.go +++ b/pkg/registry/generic/rest/proxy_test.go @@ -22,6 +22,7 @@ import ( "crypto/x509" "io" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" @@ -305,6 +306,21 @@ func TestProxyUpgrade(t *testing.T) { }, ProxyTransport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: localhostPool}}, }, + "https (valid hostname + RootCAs + custom dialer)": { + ServerFunc: func(h http.Handler) *httptest.Server { + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Errorf("https (valid hostname): proxy_test: %v", err) + } + ts := httptest.NewUnstartedServer(h) + ts.TLS = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + ts.StartTLS() + return ts + }, + ProxyTransport: &http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}, + }, } for k, tc := range testcases { diff --git a/pkg/registry/node/etcd/etcd.go b/pkg/registry/node/etcd/etcd.go index ec66b5cf46f..2b83d10f08b 100644 --- a/pkg/registry/node/etcd/etcd.go +++ b/pkg/registry/node/etcd/etcd.go @@ -31,7 +31,8 @@ import ( type REST struct { *etcdgeneric.Etcd - connection client.ConnectionInfoGetter + connection client.ConnectionInfoGetter + proxyTransport http.RoundTripper } // StatusREST implements the REST endpoint for changing the status of a pod. @@ -49,7 +50,7 @@ func (r *StatusREST) Update(ctx api.Context, obj runtime.Object) (runtime.Object } // NewREST returns a RESTStorage object that will work against nodes. -func NewREST(s storage.Interface, useCacher bool, connection client.ConnectionInfoGetter) (*REST, *StatusREST) { +func NewREST(s storage.Interface, useCacher bool, connection client.ConnectionInfoGetter, proxyTransport http.RoundTripper) (*REST, *StatusREST) { prefix := "/minions" storageInterface := s @@ -91,7 +92,7 @@ func NewREST(s storage.Interface, useCacher bool, connection client.ConnectionIn statusStore := *store statusStore.UpdateStrategy = node.StatusStrategy - return &REST{store, connection}, &StatusREST{store: &statusStore} + return &REST{store, connection, proxyTransport}, &StatusREST{store: &statusStore} } // Implement Redirector. @@ -99,5 +100,5 @@ var _ = rest.Redirector(&REST{}) // ResourceLocation returns a URL to which one can send traffic for the specified node. func (r *REST) ResourceLocation(ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { - return node.ResourceLocation(r, r.connection, ctx, id) + return node.ResourceLocation(r, r.connection, r.proxyTransport, ctx, id) } diff --git a/pkg/registry/node/etcd/etcd_test.go b/pkg/registry/node/etcd/etcd_test.go index af23cb0c1da..120735e1f09 100644 --- a/pkg/registry/node/etcd/etcd_test.go +++ b/pkg/registry/node/etcd/etcd_test.go @@ -38,7 +38,7 @@ func (fakeConnectionInfoGetter) GetConnectionInfo(host string) (string, uint, ht func newStorage(t *testing.T) (*REST, *tools.FakeEtcdClient) { etcdStorage, fakeClient := registrytest.NewEtcdStorage(t, "") - storage, _ := NewREST(etcdStorage, false, fakeConnectionInfoGetter{}) + storage, _ := NewREST(etcdStorage, false, fakeConnectionInfoGetter{}, nil) return storage, fakeClient } diff --git a/pkg/registry/node/strategy.go b/pkg/registry/node/strategy.go index 76f77e0239d..679a53b8fa9 100644 --- a/pkg/registry/node/strategy.go +++ b/pkg/registry/node/strategy.go @@ -136,7 +136,7 @@ func MatchNode(label labels.Selector, field fields.Selector) generic.Matcher { } // ResourceLocation returns an URL and transport which one can use to send traffic for the specified node. -func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGetter, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { +func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGetter, proxyTransport http.RoundTripper, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { schemeReq, name, portReq, valid := util.SplitSchemeNamePort(id) if !valid { return nil, nil, errors.NewBadRequest(fmt.Sprintf("invalid node request %q", id)) @@ -155,7 +155,7 @@ func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGet if portReq == "" || strconv.Itoa(ports.KubeletPort) == portReq { // Ignore requested scheme, use scheme provided by GetConnectionInfo - scheme, port, transport, err := connection.GetConnectionInfo(host) + scheme, port, kubeletTransport, err := connection.GetConnectionInfo(host) if err != nil { return nil, nil, err } @@ -166,8 +166,8 @@ func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGet strconv.FormatUint(uint64(port), 10), ), }, - transport, + kubeletTransport, nil } - return &url.URL{Scheme: schemeReq, Host: net.JoinHostPort(host, portReq)}, nil, nil + return &url.URL{Scheme: schemeReq, Host: net.JoinHostPort(host, portReq)}, proxyTransport, nil } diff --git a/pkg/registry/pod/etcd/etcd.go b/pkg/registry/pod/etcd/etcd.go index 5b71cd782e0..9faa40abf62 100644 --- a/pkg/registry/pod/etcd/etcd.go +++ b/pkg/registry/pod/etcd/etcd.go @@ -56,10 +56,11 @@ type PodStorage struct { // REST implements a RESTStorage for pods against etcd type REST struct { *etcdgeneric.Etcd + proxyTransport http.RoundTripper } // NewStorage returns a RESTStorage object that will work against pods. -func NewStorage(s storage.Interface, useCacher bool, k client.ConnectionInfoGetter) PodStorage { +func NewStorage(s storage.Interface, useCacher bool, k client.ConnectionInfoGetter, proxyTransport http.RoundTripper) PodStorage { prefix := "/pods" storageInterface := s @@ -106,11 +107,11 @@ func NewStorage(s storage.Interface, useCacher bool, k client.ConnectionInfoGett statusStore.UpdateStrategy = pod.StatusStrategy return PodStorage{ - Pod: &REST{store}, + Pod: &REST{store, proxyTransport}, Binding: &BindingREST{store: store}, Status: &StatusREST{store: &statusStore}, Log: &LogREST{store: store, kubeletConn: k}, - Proxy: &ProxyREST{store: store}, + Proxy: &ProxyREST{store: store, proxyTransport: proxyTransport}, Exec: &ExecREST{store: store, kubeletConn: k}, Attach: &AttachREST{store: store, kubeletConn: k}, PortForward: &PortForwardREST{store: store, kubeletConn: k}, @@ -122,7 +123,7 @@ var _ = rest.Redirector(&REST{}) // ResourceLocation returns a pods location from its HostIP func (r *REST) ResourceLocation(ctx api.Context, name string) (*url.URL, http.RoundTripper, error) { - return pod.ResourceLocation(r, ctx, name) + return pod.ResourceLocation(r, r.proxyTransport, ctx, name) } // BindingREST implements the REST endpoint for binding pods to nodes when etcd is in use. @@ -256,7 +257,8 @@ func (r *LogREST) NewGetOptions() (runtime.Object, bool, string) { // ProxyREST implements the proxy subresource for a Pod // TODO: move me into pod/rest - I'm generic to store type via ResourceGetter type ProxyREST struct { - store *etcdgeneric.Etcd + store *etcdgeneric.Etcd + proxyTransport http.RoundTripper } // Implement Connecter @@ -285,7 +287,7 @@ func (r *ProxyREST) Connect(ctx api.Context, id string, opts runtime.Object) (re if !ok { return nil, fmt.Errorf("Invalid options object: %#v", opts) } - location, transport, err := pod.ResourceLocation(r.store, ctx, id) + location, transport, err := pod.ResourceLocation(r.store, r.proxyTransport, ctx, id) if err != nil { return nil, err } diff --git a/pkg/registry/pod/etcd/etcd_test.go b/pkg/registry/pod/etcd/etcd_test.go index 0c4fd6f6e36..19cc073954b 100644 --- a/pkg/registry/pod/etcd/etcd_test.go +++ b/pkg/registry/pod/etcd/etcd_test.go @@ -38,7 +38,7 @@ import ( func newStorage(t *testing.T) (*REST, *BindingREST, *StatusREST, *tools.FakeEtcdClient) { etcdStorage, fakeClient := registrytest.NewEtcdStorage(t, "") - storage := NewStorage(etcdStorage, false, nil) + storage := NewStorage(etcdStorage, false, nil, nil) return storage.Pod, storage.Binding, storage.Status, fakeClient } @@ -740,7 +740,7 @@ func TestEtcdUpdateStatus(t *testing.T) { func TestPodLogValidates(t *testing.T) { etcdStorage, _ := registrytest.NewEtcdStorage(t, "") - storage := NewStorage(etcdStorage, false, nil) + storage := NewStorage(etcdStorage, false, nil, nil) negativeOne := int64(-1) testCases := []*api.PodLogOptions{ diff --git a/pkg/registry/pod/strategy.go b/pkg/registry/pod/strategy.go index 4029715b977..82ac5c49b0f 100644 --- a/pkg/registry/pod/strategy.go +++ b/pkg/registry/pod/strategy.go @@ -17,7 +17,6 @@ limitations under the License. package pod import ( - "crypto/tls" "fmt" "net" "net/http" @@ -47,13 +46,6 @@ type podStrategy struct { // objects via the REST API. var Strategy = podStrategy{api.Scheme, api.SimpleNameGenerator} -// PodProxyTransport is used by the API proxy to connect to pods -// Exported to allow overriding TLS options (like adding a client certificate) -var PodProxyTransport = util.SetTransportDefaults(&http.Transport{ - // Turn off hostname verification, because connections are to assigned IPs, not deterministic - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, -}) - // NamespaceScoped is true for pods. func (podStrategy) NamespaceScoped() bool { return true @@ -195,7 +187,7 @@ func getPod(getter ResourceGetter, ctx api.Context, name string) (*api.Pod, erro } // ResourceLocation returns a URL to which one can send traffic for the specified pod. -func ResourceLocation(getter ResourceGetter, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { +func ResourceLocation(getter ResourceGetter, rt http.RoundTripper, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { // Allow ID as "podname" or "podname:port" or "scheme:podname:port". // If port is not specified, try to use the first defined port on the pod. scheme, name, port, valid := util.SplitSchemeNamePort(id) @@ -227,7 +219,7 @@ func ResourceLocation(getter ResourceGetter, ctx api.Context, id string) (*url.U } else { loc.Host = net.JoinHostPort(pod.Status.PodIP, port) } - return loc, PodProxyTransport, nil + return loc, rt, nil } // LogLocation returns the log URL for a pod container. If opts.Container is blank diff --git a/pkg/registry/service/rest.go b/pkg/registry/service/rest.go index 3b17d72babf..3b524a2bc7f 100644 --- a/pkg/registry/service/rest.go +++ b/pkg/registry/service/rest.go @@ -17,7 +17,6 @@ limitations under the License. package service import ( - "crypto/tls" "fmt" "math/rand" "net" @@ -48,23 +47,18 @@ type REST struct { endpoints endpoint.Registry serviceIPs ipallocator.Interface serviceNodePorts portallocator.Interface + proxyTransport http.RoundTripper } -// ServiceProxyTransport is used by the API proxy to connect to services -// Exported to allow overriding TLS options (like adding a client certificate) -var ServiceProxyTransport = util.SetTransportDefaults(&http.Transport{ - // Turn off hostname verification, because connections are to assigned IPs, not deterministic - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, -}) - // NewStorage returns a new REST. func NewStorage(registry Registry, endpoints endpoint.Registry, serviceIPs ipallocator.Interface, - serviceNodePorts portallocator.Interface) *REST { + serviceNodePorts portallocator.Interface, proxyTransport http.RoundTripper) *REST { return &REST{ registry: registry, endpoints: endpoints, serviceIPs: serviceIPs, serviceNodePorts: serviceNodePorts, + proxyTransport: proxyTransport, } } @@ -314,7 +308,7 @@ func (rs *REST) ResourceLocation(ctx api.Context, id string) (*url.URL, http.Rou return &url.URL{ Scheme: svcScheme, Host: net.JoinHostPort(ip, strconv.Itoa(port)), - }, ServiceProxyTransport, nil + }, rs.proxyTransport, nil } } } diff --git a/pkg/registry/service/rest_test.go b/pkg/registry/service/rest_test.go index 288b1d1257c..3fe14916ee7 100644 --- a/pkg/registry/service/rest_test.go +++ b/pkg/registry/service/rest_test.go @@ -46,7 +46,7 @@ func NewTestREST(t *testing.T, endpoints *api.EndpointsList) (*REST, *registryte portRange := util.PortRange{Base: 30000, Size: 1000} portAllocator := portallocator.NewPortAllocator(portRange) - storage := NewStorage(registry, endpointRegistry, r, portAllocator) + storage := NewStorage(registry, endpointRegistry, r, portAllocator, nil) return storage, registry } diff --git a/pkg/util/http.go b/pkg/util/http.go index e8253fa80f5..8f35ce49bbf 100644 --- a/pkg/util/http.go +++ b/pkg/util/http.go @@ -17,7 +17,10 @@ limitations under the License. package util import ( + "crypto/tls" + "fmt" "io" + "net" "net/http" "net/url" "strings" @@ -62,3 +65,40 @@ func SetTransportDefaults(t *http.Transport) *http.Transport { } return t } + +type RoundTripperWrapper interface { + http.RoundTripper + WrappedRoundTripper() http.RoundTripper +} + +type DialFunc func(net, addr string) (net.Conn, error) + +func Dialer(transport http.RoundTripper) (DialFunc, error) { + if transport == nil { + return nil, nil + } + + switch transport := transport.(type) { + case *http.Transport: + return transport.Dial, nil + case RoundTripperWrapper: + return Dialer(transport.WrappedRoundTripper()) + default: + return nil, fmt.Errorf("unknown transport type: %v", transport) + } +} + +func TLSClientConfig(transport http.RoundTripper) (*tls.Config, error) { + if transport == nil { + return nil, nil + } + + switch transport := transport.(type) { + case *http.Transport: + return transport.TLSClientConfig, nil + case RoundTripperWrapper: + return TLSClientConfig(transport.WrappedRoundTripper()) + default: + return nil, fmt.Errorf("unknown transport type: %v", transport) + } +} diff --git a/pkg/util/proxy/dial.go b/pkg/util/proxy/dial.go new file mode 100644 index 00000000000..07982b79386 --- /dev/null +++ b/pkg/util/proxy/dial.go @@ -0,0 +1,106 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "net/url" + + "github.com/golang/glog" + + "k8s.io/kubernetes/pkg/util" + "k8s.io/kubernetes/third_party/golang/netutil" +) + +func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) { + dialAddr := netutil.CanonicalAddr(url) + + dialer, _ := util.Dialer(transport) + + switch url.Scheme { + case "http": + if dialer != nil { + return dialer("tcp", dialAddr) + } + return net.Dial("tcp", dialAddr) + case "https": + // Get the tls config from the transport if we recognize it + var tlsConfig *tls.Config + var tlsConn *tls.Conn + var err error + tlsConfig, _ = util.TLSClientConfig(transport) + + if dialer != nil { + // We have a dialer; use it to open the connection, then + // create a tls client using the connection. + netConn, err := dialer("tcp", dialAddr) + if err != nil { + return nil, err + } + if tlsConfig == nil { + // tls.Client requires non-nil config + glog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify") + // tls.Handshake() requires ServerName or InsecureSkipVerify + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { + // tls.Handshake() requires ServerName or InsecureSkipVerify + // infer the ServerName from the hostname we're connecting to. + inferredHost := dialAddr + if host, _, err := net.SplitHostPort(dialAddr); err == nil { + inferredHost = host + } + // Make a copy to avoid polluting the provided config + tlsConfigCopy := *tlsConfig + tlsConfigCopy.ServerName = inferredHost + tlsConfig = &tlsConfigCopy + } + tlsConn = tls.Client(netConn, tlsConfig) + if err := tlsConn.Handshake(); err != nil { + netConn.Close() + return nil, err + } + + } else { + // Dial + tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig) + if err != nil { + return nil, err + } + } + + // Return if we were configured to skip validation + if tlsConfig != nil && tlsConfig.InsecureSkipVerify { + return tlsConn, nil + } + + // Verify + host, _, _ := net.SplitHostPort(dialAddr) + if err := tlsConn.VerifyHostname(host); err != nil { + tlsConn.Close() + return nil, err + } + + return tlsConn, nil + default: + return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme) + } +} diff --git a/pkg/util/proxy/transport.go b/pkg/util/proxy/transport.go index 80855f46b8b..b3365784524 100644 --- a/pkg/util/proxy/transport.go +++ b/pkg/util/proxy/transport.go @@ -31,6 +31,7 @@ import ( "golang.org/x/net/html" "golang.org/x/net/html/atom" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/sets" ) @@ -118,6 +119,12 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { return t.rewriteResponse(req, resp) } +var _ = util.RoundTripperWrapper(&Transport{}) + +func (rt *Transport) WrappedRoundTripper() http.RoundTripper { + return rt.RoundTripper +} + // rewriteURL rewrites a single URL to go through the proxy, if the URL refers // to the same host as sourceURL, which is the page on which the target URL // occurred. If any error occurs (e.g. parsing), it returns targetURL. diff --git a/test/e2e/kubectl.go b/test/e2e/kubectl.go index 88729a403aa..9ee2916b86d 100644 --- a/test/e2e/kubectl.go +++ b/test/e2e/kubectl.go @@ -669,6 +669,9 @@ var _ = Describe("Kubectl client", func() { By("curling proxy /api/ output") localAddr := fmt.Sprintf("http://localhost:%d/api/", port) apiVersions, err := getAPIVersions(localAddr) + if err != nil { + Failf("Expected at least one supported apiversion, got error %v", err) + } if len(apiVersions.Versions) < 1 { Failf("Expected at least one supported apiversion, got %v", apiVersions) }