From 826459e51e7f7d8c9dd593be735ae684a2c61c5b Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Fri, 9 Oct 2015 02:25:25 -0400 Subject: [PATCH 1/2] Allow specifying scheme when proxying --- pkg/registry/generic/rest/proxy.go | 31 ++++++----- pkg/registry/generic/rest/proxy_test.go | 2 +- pkg/registry/node/strategy.go | 5 +- pkg/registry/pod/etcd/etcd.go | 15 +++--- pkg/registry/pod/strategy.go | 22 +++++--- pkg/registry/service/rest.go | 19 ++++--- pkg/registry/service/rest_test.go | 12 +++++ pkg/util/port_split.go | 59 ++++++++++++++++---- pkg/util/port_split_test.go | 72 ++++++++++++++++++++----- test/e2e/proxy.go | 66 +++++++++++++++++++---- test/images/porter/pod.json | 2 +- 11 files changed, 235 insertions(+), 70 deletions(-) diff --git a/pkg/registry/generic/rest/proxy.go b/pkg/registry/generic/rest/proxy.go index 45611931bd3..14bd0e0e27d 100644 --- a/pkg/registry/generic/rest/proxy.go +++ b/pkg/registry/generic/rest/proxy.go @@ -41,19 +41,23 @@ import ( type UpgradeAwareProxyHandler struct { UpgradeRequired bool Location *url.URL - Transport http.RoundTripper - FlushInterval time.Duration - MaxBytesPerSec int64 - err error + // Transport provides an optional round tripper to use to proxy. If nil, the default proxy transport is used + Transport http.RoundTripper + // WrapTransport indicates whether the provided Transport should be wrapped with default proxy transport behavior (URL rewriting, X-Forwarded-* header setting) + WrapTransport bool + FlushInterval time.Duration + MaxBytesPerSec int64 + err error } const defaultFlushInterval = 200 * time.Millisecond // NewUpgradeAwareProxyHandler creates a new proxy handler with a default flush interval -func NewUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, upgradeRequired bool) *UpgradeAwareProxyHandler { +func NewUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool) *UpgradeAwareProxyHandler { return &UpgradeAwareProxyHandler{ Location: location, Transport: transport, + WrapTransport: wrapTransport, UpgradeRequired: upgradeRequired, FlushInterval: defaultFlushInterval, } @@ -101,8 +105,8 @@ func (h *UpgradeAwareProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Re return } - if h.Transport == nil { - h.Transport = h.defaultProxyTransport(req.URL) + if h.Transport == nil || h.WrapTransport { + h.Transport = h.defaultProxyTransport(req.URL, h.Transport) } newReq, err := http.NewRequest(req.Method, loc.String(), req.Body) @@ -258,7 +262,7 @@ func (h *UpgradeAwareProxyHandler) dialURL() (net.Conn, error) { } } -func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.RoundTripper { +func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper { scheme := url.Scheme host := url.Host suffix := h.Location.Path @@ -266,13 +270,14 @@ func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL) http.Roun suffix += "/" } pathPrepend := strings.TrimSuffix(url.Path, suffix) - internalTransport := &proxy.Transport{ - Scheme: scheme, - Host: host, - PathPrepend: pathPrepend, + rewritingTransport := &proxy.Transport{ + Scheme: scheme, + Host: host, + PathPrepend: pathPrepend, + RoundTripper: internalTransport, } return &corsRemovingTransport{ - RoundTripper: internalTransport, + RoundTripper: rewritingTransport, } } diff --git a/pkg/registry/generic/rest/proxy_test.go b/pkg/registry/generic/rest/proxy_test.go index bb2e8b56152..1f20f3dac87 100644 --- a/pkg/registry/generic/rest/proxy_test.go +++ b/pkg/registry/generic/rest/proxy_test.go @@ -388,7 +388,7 @@ func TestDefaultProxyTransport(t *testing.T) { h := UpgradeAwareProxyHandler{ Location: locURL, } - result := h.defaultProxyTransport(URL) + result := h.defaultProxyTransport(URL, nil) transport := result.(*corsRemovingTransport).RoundTripper.(*proxy.Transport) if transport.Scheme != test.expectedScheme { t.Errorf("%s: unexpected scheme. Actual: %s, Expected: %s", test.name, transport.Scheme, test.expectedScheme) diff --git a/pkg/registry/node/strategy.go b/pkg/registry/node/strategy.go index 939fc106dca..76f77e0239d 100644 --- a/pkg/registry/node/strategy.go +++ b/pkg/registry/node/strategy.go @@ -137,7 +137,7 @@ func MatchNode(label labels.Selector, field fields.Selector) generic.Matcher { // ResourceLocation returns an URL and transport which one can use to send traffic for the specified node. func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGetter, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { - name, portReq, valid := util.SplitPort(id) + schemeReq, name, portReq, valid := util.SplitSchemeNamePort(id) if !valid { return nil, nil, errors.NewBadRequest(fmt.Sprintf("invalid node request %q", id)) } @@ -154,6 +154,7 @@ func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGet host := hostIP.String() if portReq == "" || strconv.Itoa(ports.KubeletPort) == portReq { + // Ignore requested scheme, use scheme provided by GetConnectionInfo scheme, port, transport, err := connection.GetConnectionInfo(host) if err != nil { return nil, nil, err @@ -168,5 +169,5 @@ func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGet transport, nil } - return &url.URL{Host: net.JoinHostPort(host, portReq)}, nil, nil + return &url.URL{Scheme: schemeReq, Host: net.JoinHostPort(host, portReq)}, nil, nil } diff --git a/pkg/registry/pod/etcd/etcd.go b/pkg/registry/pod/etcd/etcd.go index 4779e8aa950..5b71cd782e0 100644 --- a/pkg/registry/pod/etcd/etcd.go +++ b/pkg/registry/pod/etcd/etcd.go @@ -285,12 +285,13 @@ func (r *ProxyREST) Connect(ctx api.Context, id string, opts runtime.Object) (re if !ok { return nil, fmt.Errorf("Invalid options object: %#v", opts) } - location, _, err := pod.ResourceLocation(r.store, ctx, id) + location, transport, err := pod.ResourceLocation(r.store, ctx, id) if err != nil { return nil, err } location.Path = path.Join(location.Path, proxyOpts.Path) - return newUpgradeAwareProxyHandler(location, nil, false), nil + // Return a proxy handler that uses the desired transport, wrapped with additional proxy handling (to get URL rewriting, X-Forwarded-* headers, etc) + return newThrottledUpgradeAwareProxyHandler(location, transport, true, false), nil } // Support both GET and POST methods. We must support GET for browsers that want to use WebSockets. @@ -321,7 +322,7 @@ func (r *AttachREST) Connect(ctx api.Context, name string, opts runtime.Object) if err != nil { return nil, err } - return genericrest.NewUpgradeAwareProxyHandler(location, transport, true), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, false, true), nil } // NewConnectOptions returns the versioned object that represents exec parameters @@ -359,7 +360,7 @@ func (r *ExecREST) Connect(ctx api.Context, name string, opts runtime.Object) (r if err != nil { return nil, err } - return newUpgradeAwareProxyHandler(location, transport, true), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, false, true), nil } // NewConnectOptions returns the versioned object that represents exec parameters @@ -403,11 +404,11 @@ func (r *PortForwardREST) Connect(ctx api.Context, name string, opts runtime.Obj if err != nil { return nil, err } - return newUpgradeAwareProxyHandler(location, transport, true), nil + return newThrottledUpgradeAwareProxyHandler(location, transport, false, true), nil } -func newUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, upgradeRequired bool) *genericrest.UpgradeAwareProxyHandler { - handler := genericrest.NewUpgradeAwareProxyHandler(location, transport, upgradeRequired) +func newThrottledUpgradeAwareProxyHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool) *genericrest.UpgradeAwareProxyHandler { + handler := genericrest.NewUpgradeAwareProxyHandler(location, transport, wrapTransport, upgradeRequired) handler.MaxBytesPerSec = capabilities.Get().PerConnectionBandwidthLimitBytesPerSec return handler } diff --git a/pkg/registry/pod/strategy.go b/pkg/registry/pod/strategy.go index cc721099a52..4029715b977 100644 --- a/pkg/registry/pod/strategy.go +++ b/pkg/registry/pod/strategy.go @@ -17,6 +17,7 @@ limitations under the License. package pod import ( + "crypto/tls" "fmt" "net" "net/http" @@ -46,6 +47,13 @@ type podStrategy struct { // objects via the REST API. var Strategy = podStrategy{api.Scheme, api.SimpleNameGenerator} +// PodProxyTransport is used by the API proxy to connect to pods +// Exported to allow overriding TLS options (like adding a client certificate) +var PodProxyTransport = util.SetTransportDefaults(&http.Transport{ + // Turn off hostname verification, because connections are to assigned IPs, not deterministic + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, +}) + // NamespaceScoped is true for pods. func (podStrategy) NamespaceScoped() bool { return true @@ -188,9 +196,9 @@ func getPod(getter ResourceGetter, ctx api.Context, name string) (*api.Pod, erro // ResourceLocation returns a URL to which one can send traffic for the specified pod. func ResourceLocation(getter ResourceGetter, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { - // Allow ID as "podname" or "podname:port". If port is not specified, - // try to use the first defined port on the pod. - name, port, valid := util.SplitPort(id) + // Allow ID as "podname" or "podname:port" or "scheme:podname:port". + // If port is not specified, try to use the first defined port on the pod. + scheme, name, port, valid := util.SplitSchemeNamePort(id) if !valid { return nil, nil, errors.NewBadRequest(fmt.Sprintf("invalid pod request %q", id)) } @@ -211,15 +219,15 @@ func ResourceLocation(getter ResourceGetter, ctx api.Context, id string) (*url.U } } - // We leave off the scheme ('http://') because we have no idea what sort of server - // is listening at this endpoint. - loc := &url.URL{} + loc := &url.URL{ + Scheme: scheme, + } if port == "" { loc.Host = pod.Status.PodIP } else { loc.Host = net.JoinHostPort(pod.Status.PodIP, port) } - return loc, nil, nil + return loc, PodProxyTransport, nil } // LogLocation returns the log URL for a pod container. If opts.Container is blank diff --git a/pkg/registry/service/rest.go b/pkg/registry/service/rest.go index 075d3d84ac7..3b17d72babf 100644 --- a/pkg/registry/service/rest.go +++ b/pkg/registry/service/rest.go @@ -17,6 +17,7 @@ limitations under the License. package service import ( + "crypto/tls" "fmt" "math/rand" "net" @@ -49,6 +50,13 @@ type REST struct { serviceNodePorts portallocator.Interface } +// ServiceProxyTransport is used by the API proxy to connect to services +// Exported to allow overriding TLS options (like adding a client certificate) +var ServiceProxyTransport = util.SetTransportDefaults(&http.Transport{ + // Turn off hostname verification, because connections are to assigned IPs, not deterministic + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, +}) + // NewStorage returns a new REST. func NewStorage(registry Registry, endpoints endpoint.Registry, serviceIPs ipallocator.Interface, serviceNodePorts portallocator.Interface) *REST { @@ -277,8 +285,8 @@ var _ = rest.Redirector(&REST{}) // ResourceLocation returns a URL to which one can send traffic for the specified service. func (rs *REST) ResourceLocation(ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { - // Allow ID as "svcname" or "svcname:port". - svcName, portStr, valid := util.SplitPort(id) + // Allow ID as "svcname", "svcname:port", or "scheme:svcname:port". + svcScheme, svcName, portStr, valid := util.SplitSchemeNamePort(id) if !valid { return nil, nil, errors.NewBadRequest(fmt.Sprintf("invalid service request %q", id)) } @@ -303,11 +311,10 @@ func (rs *REST) ResourceLocation(ctx api.Context, id string) (*url.URL, http.Rou // Pick a random address. ip := ss.Addresses[rand.Intn(len(ss.Addresses))].IP port := ss.Ports[i].Port - // We leave off the scheme ('http://') because we have no idea what sort of server - // is listening at this endpoint. return &url.URL{ - Host: net.JoinHostPort(ip, strconv.Itoa(port)), - }, nil, nil + Scheme: svcScheme, + Host: net.JoinHostPort(ip, strconv.Itoa(port)), + }, ServiceProxyTransport, nil } } } diff --git a/pkg/registry/service/rest_test.go b/pkg/registry/service/rest_test.go index 2f7e76e1619..288b1d1257c 100644 --- a/pkg/registry/service/rest_test.go +++ b/pkg/registry/service/rest_test.go @@ -491,6 +491,18 @@ func TestServiceRegistryResourceLocation(t *testing.T) { t.Errorf("Expected %v, but got %v", e, a) } + // Test a scheme + name + port. + location, _, err = redirector.ResourceLocation(ctx, "https:foo:p") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if location == nil { + t.Errorf("Unexpected nil: %v", location) + } + if e, a := "https://1.2.3.4:93", location.String(); e != a { + t.Errorf("Expected %v, but got %v", e, a) + } + // Test a non-existent name + port. location, _, err = redirector.ResourceLocation(ctx, "foo:q") if err == nil { diff --git a/pkg/util/port_split.go b/pkg/util/port_split.go index c991dc67588..176271189a6 100644 --- a/pkg/util/port_split.go +++ b/pkg/util/port_split.go @@ -18,23 +18,60 @@ package util import ( "strings" + + "k8s.io/kubernetes/pkg/util/sets" ) -// Takes a string of the form "name:port" or "name". -// * If id is of the form "name" or "name:", then return (name, "", true) -// * If id is of the form "name:port", then return (name, port, true) -// * Otherwise, return ("", "", false) -// Additionally, name must be non-empty or valid will be returned false. +var validSchemes = sets.NewString("http", "https", "") + +// SplitSchemeNamePort takes a string of the following forms: +// * "", returns "", "","", true +// * ":", returns "", "","",true +// * "::", returns "","","",true // +// Name must be non-empty or valid will be returned false. +// Scheme must be "http" or "https" if specified // Port is returned as a string, and it is not required to be numeric (could be // used for a named port, for example). -func SplitPort(id string) (name, port string, valid bool) { +func SplitSchemeNamePort(id string) (scheme, name, port string, valid bool) { parts := strings.Split(id, ":") - if len(parts) > 2 { - return "", "", false + switch len(parts) { + case 1: + name = parts[0] + case 2: + name = parts[0] + port = parts[1] + case 3: + scheme = parts[0] + name = parts[1] + port = parts[2] + default: + return "", "", "", false } - if len(parts) == 2 { - return parts[0], parts[1], len(parts[0]) > 0 + + if len(name) > 0 && validSchemes.Has(scheme) { + return scheme, name, port, true + } else { + return "", "", "", false } - return id, "", len(id) > 0 +} + +// JoinSchemeNamePort returns a string that specifies the scheme, name, and port: +// * "" +// * ":" +// * "::" +// None of the parameters may contain a ':' character +// Name is required +// Scheme must be "", "http", or "https" +func JoinSchemeNamePort(scheme, name, port string) string { + if len(scheme) > 0 { + // Must include three segments to specify scheme + return scheme + ":" + name + ":" + port + } + if len(port) > 0 { + // Must include two segments to specify port + return name + ":" + port + } + // Return name alone + return name } diff --git a/pkg/util/port_split_test.go b/pkg/util/port_split_test.go index 468059cb5d7..9d9e5fb0ff2 100644 --- a/pkg/util/port_split_test.go +++ b/pkg/util/port_split_test.go @@ -20,11 +20,12 @@ import ( "testing" ) -func TestSplitPort(t *testing.T) { +func TestSplitSchemeNamePort(t *testing.T) { table := []struct { - in string - name, port string - valid bool + in string + name, port, scheme string + valid bool + normalized bool }{ { in: "aoeu:asdf", @@ -32,26 +33,50 @@ func TestSplitPort(t *testing.T) { port: "asdf", valid: true, }, { - in: "aoeu:", - name: "aoeu", - valid: true, + in: "http:aoeu:asdf", + scheme: "http", + name: "aoeu", + port: "asdf", + valid: true, }, { - in: ":asdf", - name: "", - port: "asdf", + in: "https:aoeu:", + scheme: "https", + name: "aoeu", + port: "", + valid: true, + normalized: false, }, { - in: "aoeu:asdf:htns", + in: "https:aoeu:asdf", + scheme: "https", + name: "aoeu", + port: "asdf", + valid: true, + }, { + in: "aoeu:", + name: "aoeu", + valid: true, + normalized: false, + }, { + in: ":asdf", + valid: false, + }, { + in: "aoeu:asdf:htns", + valid: false, }, { in: "aoeu", name: "aoeu", valid: true, }, { - in: "", + in: "", + valid: false, }, } for _, item := range table { - name, port, valid := SplitPort(item.in) + scheme, name, port, valid := SplitSchemeNamePort(item.in) + if e, a := item.scheme, scheme; e != a { + t.Errorf("%q: Wanted %q, got %q", item.in, e, a) + } if e, a := item.name, name; e != a { t.Errorf("%q: Wanted %q, got %q", item.in, e, a) } @@ -61,5 +86,26 @@ func TestSplitPort(t *testing.T) { if e, a := item.valid, valid; e != a { t.Errorf("%q: Wanted %t, got %t", item.in, e, a) } + + // Make sure valid items round trip through JoinSchemeNamePort + if item.valid { + out := JoinSchemeNamePort(scheme, name, port) + if item.normalized && out != item.in { + t.Errorf("%q: Wanted %s, got %s", item.in, item.in, out) + } + scheme, name, port, valid := SplitSchemeNamePort(out) + if e, a := item.scheme, scheme; e != a { + t.Errorf("%q: Wanted %q, got %q", item.in, e, a) + } + if e, a := item.name, name; e != a { + t.Errorf("%q: Wanted %q, got %q", item.in, e, a) + } + if e, a := item.port, port; e != a { + t.Errorf("%q: Wanted %q, got %q", item.in, e, a) + } + if e, a := item.valid, valid; e != a { + t.Errorf("%q: Wanted %t, got %t", item.in, e, a) + } + } } } diff --git a/test/e2e/proxy.go b/test/e2e/proxy.go index 41fef0c6c39..69230b7b26b 100644 --- a/test/e2e/proxy.go +++ b/test/e2e/proxy.go @@ -78,6 +78,16 @@ func proxyContext(version string) { Port: 81, TargetPort: util.NewIntOrStringFromInt(162), }, + { + Name: "tlsportname1", + Port: 443, + TargetPort: util.NewIntOrStringFromString("tlsdest1"), + }, + { + Name: "tlsportname2", + Port: 444, + TargetPort: util.NewIntOrStringFromInt(462), + }, }, }, }) @@ -93,7 +103,7 @@ func proxyContext(version string) { pods := []*api.Pod{} cfg := RCConfig{ Client: f.Client, - Image: "gcr.io/google_containers/porter:59ad46ed2c56ba50fa7f1dc176c07c37", + Image: "gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab", Name: service.Name, Namespace: f.Namespace.Name, Replicas: 1, @@ -102,10 +112,17 @@ func proxyContext(version string) { "SERVE_PORT_80": `test`, "SERVE_PORT_160": "foo", "SERVE_PORT_162": "bar", + + "SERVE_TLS_PORT_443": `test`, + "SERVE_TLS_PORT_460": `tls baz`, + "SERVE_TLS_PORT_462": `tls qux`, }, Ports: map[string]int{ "dest1": 160, "dest2": 162, + + "tlsdest1": 460, + "tlsdest2": 462, }, Labels: labels, CreatedPods: &pods, @@ -116,14 +133,44 @@ func proxyContext(version string) { Expect(f.WaitForAnEndpoint(service.Name)).NotTo(HaveOccurred()) // Try proxying through the service and directly to through the pod. - svcPrefix := prefix + "/proxy/namespaces/" + f.Namespace.Name + "/services/" + service.Name - podPrefix := prefix + "/proxy/namespaces/" + f.Namespace.Name + "/pods/" + pods[0].Name + svcProxyURL := func(scheme, port string) string { + return prefix + "/proxy/namespaces/" + f.Namespace.Name + "/services/" + util.JoinSchemeNamePort(scheme, service.Name, port) + } + podProxyURL := func(scheme, port string) string { + return prefix + "/proxy/namespaces/" + f.Namespace.Name + "/pods/" + util.JoinSchemeNamePort(scheme, pods[0].Name, port) + } + subresourcePodProxyURL := func(scheme, port string) string { + return prefix + "/namespaces/" + f.Namespace.Name + "/pods/" + util.JoinSchemeNamePort(scheme, pods[0].Name, port) + "/proxy" + } expectations := map[string]string{ - svcPrefix + ":portname1/": "foo", - svcPrefix + ":portname2/": "bar", - podPrefix + ":80/": `test`, - podPrefix + ":160/": "foo", - podPrefix + ":162/": "bar", + svcProxyURL("", "portname1") + "/": "foo", + svcProxyURL("", "portname2") + "/": "bar", + + svcProxyURL("http", "portname1") + "/": "foo", + svcProxyURL("http", "portname2") + "/": "bar", + + svcProxyURL("https", "tlsportname1") + "/": "tls baz", + svcProxyURL("https", "tlsportname2") + "/": "tls qux", + + podProxyURL("", "80") + "/": `test`, + podProxyURL("", "160") + "/": "foo", + podProxyURL("", "162") + "/": "bar", + + podProxyURL("http", "80") + "/": `test`, + podProxyURL("http", "160") + "/": "foo", + podProxyURL("http", "162") + "/": "bar", + + subresourcePodProxyURL("", "") + "/": `test`, + subresourcePodProxyURL("", "80") + "/": `test`, + subresourcePodProxyURL("http", "80") + "/": `test`, + subresourcePodProxyURL("", "160") + "/": "foo", + subresourcePodProxyURL("http", "160") + "/": "foo", + subresourcePodProxyURL("", "162") + "/": "bar", + subresourcePodProxyURL("http", "162") + "/": "bar", + + subresourcePodProxyURL("https", "443") + "/": `test`, + subresourcePodProxyURL("https", "460") + "/": "tls baz", + subresourcePodProxyURL("https", "462") + "/": "tls qux", // TODO: below entries don't work, but I believe we should make them work. // svcPrefix + ":80": "foo", // svcPrefix + ":81": "bar", @@ -159,7 +206,8 @@ func proxyContext(version string) { recordError(fmt.Sprintf("%v: path %v took %v > 15s", i, path, d)) } }(i, path, val) - time.Sleep(150 * time.Millisecond) + // default QPS is 5 + time.Sleep(200 * time.Millisecond) } } wg.Wait() diff --git a/test/images/porter/pod.json b/test/images/porter/pod.json index 894fc76a9a8..2a94ee1e7e9 100644 --- a/test/images/porter/pod.json +++ b/test/images/porter/pod.json @@ -8,7 +8,7 @@ "containers": [ { "name": "porter", - "image": "gcr.io/google_containers/porter:59ad46ed2c56ba50fa7f1dc176c07c37", + "image": "gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab", "env": [ { "name": "SERVE_PORT_80", From 1043126135231cef469f7055f561869d0edaf6e6 Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Fri, 9 Oct 2015 01:18:16 -0400 Subject: [PATCH 2/2] Refactor SSH tunneling, fix proxy transport TLS/Dial extraction --- cmd/kube-apiserver/app/server.go | 37 +++- pkg/apiserver/api_installer.go | 3 +- pkg/apiserver/apiserver.go | 2 - pkg/apiserver/proxy.go | 55 +---- pkg/apiserver/proxy_test.go | 16 ++ pkg/client/chaosclient/chaosclient.go | 8 + pkg/client/unversioned/debugging.go | 7 + pkg/client/unversioned/transport.go | 20 ++ pkg/master/master.go | 253 +++-------------------- pkg/master/master_test.go | 128 ++---------- pkg/master/tunneler.go | 262 ++++++++++++++++++++++++ pkg/master/tunneler_test.go | 139 +++++++++++++ pkg/registry/generic/rest/proxy.go | 85 +------- pkg/registry/generic/rest/proxy_test.go | 16 ++ pkg/registry/node/etcd/etcd.go | 9 +- pkg/registry/node/etcd/etcd_test.go | 2 +- pkg/registry/node/strategy.go | 8 +- pkg/registry/pod/etcd/etcd.go | 14 +- pkg/registry/pod/etcd/etcd_test.go | 4 +- pkg/registry/pod/strategy.go | 12 +- pkg/registry/service/rest.go | 14 +- pkg/registry/service/rest_test.go | 2 +- pkg/util/http.go | 40 ++++ pkg/util/proxy/dial.go | 106 ++++++++++ pkg/util/proxy/transport.go | 7 + test/e2e/kubectl.go | 3 + 26 files changed, 739 insertions(+), 513 deletions(-) create mode 100644 pkg/master/tunneler.go create mode 100644 pkg/master/tunneler_test.go create mode 100644 pkg/util/proxy/dial.go diff --git a/cmd/kube-apiserver/app/server.go b/cmd/kube-apiserver/app/server.go index 6419497c053..19245c45cd8 100644 --- a/cmd/kube-apiserver/app/server.go +++ b/cmd/kube-apiserver/app/server.go @@ -376,6 +376,30 @@ func (s *APIServer) Run(_ []string) error { glog.Fatalf("Cloud provider could not be initialized: %v", err) } + // Setup tunneler if needed + var tunneler master.Tunneler + var proxyDialerFn apiserver.ProxyDialerFunc + if len(s.SSHUser) > 0 { + // Get ssh key distribution func, if supported + var installSSH master.InstallSSHKey + if cloud != nil { + if instances, supported := cloud.Instances(); supported { + installSSH = instances.AddSSHKeyToAllInstances + } + } + + // Set up the tunneler + tunneler = master.NewSSHTunneler(s.SSHUser, s.SSHKeyfile, installSSH) + + // Use the tunneler's dialer to connect to the kubelet + s.KubeletConfig.Dial = tunneler.Dial + // Use the tunneler's dialer when proxying to pods, services, and nodes + proxyDialerFn = tunneler.Dial + } + + // Proxying to pods and services is IP-based... don't expect to be able to verify the hostname + proxyTLSClientConfig := &tls.Config{InsecureSkipVerify: true} + kubeletClient, err := client.NewKubeletClient(&s.KubeletConfig) if err != nil { glog.Fatalf("Failure to start kubelet client: %v", err) @@ -508,12 +532,7 @@ func (s *APIServer) Run(_ []string) error { } } } - var installSSH master.InstallSSHKey - if cloud != nil { - if instances, supported := cloud.Instances(); supported { - installSSH = instances.AddSSHKeyToAllInstances - } - } + config := &master.Config{ StorageDestinations: storageDestinations, StorageVersions: storageVersions, @@ -542,9 +561,9 @@ func (s *APIServer) Run(_ []string) error { ClusterName: s.ClusterName, ExternalHost: s.ExternalHost, MinRequestTimeout: s.MinRequestTimeout, - SSHUser: s.SSHUser, - SSHKeyfile: s.SSHKeyfile, - InstallSSHKey: installSSH, + ProxyDialer: proxyDialerFn, + ProxyTLSClientConfig: proxyTLSClientConfig, + Tunneler: tunneler, ServiceNodePortRange: s.ServiceNodePortRange, KubernetesServiceNodePort: s.KubernetesServiceNodePort, } diff --git a/pkg/apiserver/api_installer.go b/pkg/apiserver/api_installer.go index 54a2c188f13..bd35e8b32f5 100644 --- a/pkg/apiserver/api_installer.go +++ b/pkg/apiserver/api_installer.go @@ -41,7 +41,6 @@ type APIInstaller struct { info *APIRequestInfoResolver prefix string // Path prefix where API resources are to be registered. minRequestTimeout time.Duration - proxyDialerFn ProxyDialerFunc } // Struct capturing information about an action ("GET", "POST", "WATCH", PROXY", etc). @@ -64,7 +63,7 @@ var errEmptyName = errors.NewBadRequest("name must be provided") func (a *APIInstaller) Install(ws *restful.WebService) (apiResources []api.APIResource, errors []error) { errors = make([]error, 0) - proxyHandler := (&ProxyHandler{a.prefix + "/proxy/", a.group.Storage, a.group.Codec, a.group.Context, a.info, a.proxyDialerFn}) + proxyHandler := (&ProxyHandler{a.prefix + "/proxy/", a.group.Storage, a.group.Codec, a.group.Context, a.info}) // Register the paths in a deterministic (sorted) order to get a deterministic swagger spec. paths := make([]string, len(a.group.Storage)) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 4314e0d1320..3e6daf34d61 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -105,7 +105,6 @@ type APIGroupVersion struct { Admit admission.Interface Context api.RequestContextMapper - ProxyDialerFn ProxyDialerFunc MinRequestTimeout time.Duration } @@ -164,7 +163,6 @@ func (g *APIGroupVersion) newInstaller() *APIInstaller { info: g.APIRequestInfoResolver, prefix: prefix, minRequestTimeout: g.MinRequestTimeout, - proxyDialerFn: g.ProxyDialerFn, } return installer } diff --git a/pkg/apiserver/proxy.go b/pkg/apiserver/proxy.go index be33a05c588..3398ab71ff3 100644 --- a/pkg/apiserver/proxy.go +++ b/pkg/apiserver/proxy.go @@ -17,11 +17,8 @@ limitations under the License. package apiserver import ( - "crypto/tls" - "fmt" "io" "math/rand" - "net" "net/http" "net/http/httputil" "net/url" @@ -40,7 +37,6 @@ import ( proxyutil "k8s.io/kubernetes/pkg/util/proxy" "github.com/golang/glog" - "k8s.io/kubernetes/third_party/golang/netutil" ) // ProxyHandler provides a http.Handler which will proxy traffic to locations @@ -51,8 +47,6 @@ type ProxyHandler struct { codec runtime.Codec context api.RequestContextMapper apiRequestInfoResolver *APIRequestInfoResolver - - dial func(network, addr string) (net.Conn, error) } func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -125,11 +119,8 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { httpCode = http.StatusNotFound return } - // If we have a custom dialer, and no pre-existing transport, initialize it to use the dialer. - if roundTripper == nil && r.dial != nil { - glog.V(5).Infof("[%x: %v] making a dial-only transport...", proxyHandlerTraceID, req.URL) - roundTripper = &http.Transport{Dial: r.dial} - } else if roundTripper != nil { + + if roundTripper != nil { glog.V(5).Infof("[%x: %v] using transport %T...", proxyHandlerTraceID, req.URL, roundTripper) } @@ -217,7 +208,7 @@ func (r *ProxyHandler) tryUpgrade(w http.ResponseWriter, req, newReq *http.Reque if !httpstream.IsUpgradeRequest(req) { return false } - backendConn, err := dialURL(location, transport) + backendConn, err := proxyutil.DialURL(location, transport) if err != nil { status := errToAPIStatus(err) writeJSON(status.Code, r.codec, status, w, true) @@ -264,46 +255,6 @@ func (r *ProxyHandler) tryUpgrade(w http.ResponseWriter, req, newReq *http.Reque return true } -func dialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) { - dialAddr := netutil.CanonicalAddr(url) - - switch url.Scheme { - case "http": - return net.Dial("tcp", dialAddr) - case "https": - // Get the tls config from the transport if we recognize it - var tlsConfig *tls.Config - if transport != nil { - httpTransport, ok := transport.(*http.Transport) - if ok { - tlsConfig = httpTransport.TLSClientConfig - } - } - - // Dial - tlsConn, err := tls.Dial("tcp", dialAddr, tlsConfig) - if err != nil { - return nil, err - } - - // Return if we were configured to skip validation - if tlsConfig != nil && tlsConfig.InsecureSkipVerify { - return tlsConn, nil - } - - // Verify - host, _, _ := net.SplitHostPort(dialAddr) - if err := tlsConn.VerifyHostname(host); err != nil { - tlsConn.Close() - return nil, err - } - - return tlsConn, nil - default: - return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme) - } -} - // borrowed from net/http/httputil/reverseproxy.go func singleJoiningSlash(a, b string) string { aslash := strings.HasSuffix(a, "/") diff --git a/pkg/apiserver/proxy_test.go b/pkg/apiserver/proxy_test.go index ddbd6c063c9..3980181e2b4 100644 --- a/pkg/apiserver/proxy_test.go +++ b/pkg/apiserver/proxy_test.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" @@ -171,6 +172,21 @@ func TestProxyUpgrade(t *testing.T) { }, ProxyTransport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: localhostPool}}, }, + "https (valid hostname + RootCAs + custom dialer)": { + ServerFunc: func(h http.Handler) *httptest.Server { + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Errorf("https (valid hostname): proxy_test: %v", err) + } + ts := httptest.NewUnstartedServer(h) + ts.TLS = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + ts.StartTLS() + return ts + }, + ProxyTransport: &http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}, + }, } for k, tc := range testcases { diff --git a/pkg/client/chaosclient/chaosclient.go b/pkg/client/chaosclient/chaosclient.go index ec50ce2f625..b74ee80a344 100644 --- a/pkg/client/chaosclient/chaosclient.go +++ b/pkg/client/chaosclient/chaosclient.go @@ -28,6 +28,8 @@ import ( "net/http" "reflect" "runtime" + + "k8s.io/kubernetes/pkg/util" ) // chaosrt provides the ability to perform simulations of HTTP client failures @@ -86,6 +88,12 @@ func (rt *chaosrt) RoundTrip(req *http.Request) (*http.Response, error) { return rt.rt.RoundTrip(req) } +var _ = util.RoundTripperWrapper(&chaosrt{}) + +func (rt *chaosrt) WrappedRoundTripper() http.RoundTripper { + return rt.rt +} + // Seed represents a consistent stream of chaos. type Seed struct { *rand.Rand diff --git a/pkg/client/unversioned/debugging.go b/pkg/client/unversioned/debugging.go index df43e8984d0..76f8c2caa58 100644 --- a/pkg/client/unversioned/debugging.go +++ b/pkg/client/unversioned/debugging.go @@ -23,6 +23,7 @@ import ( "github.com/golang/glog" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/sets" ) @@ -133,3 +134,9 @@ func (rt *DebuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e return response, err } + +var _ = util.RoundTripperWrapper(&DebuggingRoundTripper{}) + +func (rt *DebuggingRoundTripper) WrappedRoundTripper() http.RoundTripper { + return rt.delegatedRoundTripper +} diff --git a/pkg/client/unversioned/transport.go b/pkg/client/unversioned/transport.go index ecb73dc1e7a..f31b1a47e12 100644 --- a/pkg/client/unversioned/transport.go +++ b/pkg/client/unversioned/transport.go @@ -22,6 +22,8 @@ import ( "fmt" "io/ioutil" "net/http" + + "k8s.io/kubernetes/pkg/util" ) type userAgentRoundTripper struct { @@ -42,6 +44,12 @@ func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, e return rt.rt.RoundTrip(req) } +var _ = util.RoundTripperWrapper(&userAgentRoundTripper{}) + +func (rt *userAgentRoundTripper) WrappedRoundTripper() http.RoundTripper { + return rt.rt +} + type basicAuthRoundTripper struct { username string password string @@ -63,6 +71,12 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e return rt.rt.RoundTrip(req) } +var _ = util.RoundTripperWrapper(&basicAuthRoundTripper{}) + +func (rt *basicAuthRoundTripper) WrappedRoundTripper() http.RoundTripper { + return rt.rt +} + type bearerAuthRoundTripper struct { bearer string rt http.RoundTripper @@ -84,6 +98,12 @@ func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, return rt.rt.RoundTrip(req) } +var _ = util.RoundTripperWrapper(&bearerAuthRoundTripper{}) + +func (rt *bearerAuthRoundTripper) WrappedRoundTripper() http.RoundTripper { + return rt.rt +} + // TLSConfigFor returns a tls.Config that will provide the transport level security defined // by the provided Config. Will return nil if no transport level security is requested. func TLSConfigFor(config *Config) (*tls.Config, error) { diff --git a/pkg/master/master.go b/pkg/master/master.go index 439f04609b9..b4ef96b2f5b 100644 --- a/pkg/master/master.go +++ b/pkg/master/master.go @@ -17,18 +17,15 @@ limitations under the License. package master import ( + "crypto/tls" "fmt" - "io/ioutil" - "math/rand" "net" "net/http" "net/http/pprof" "net/url" - "os" "strconv" "strings" "sync" - "sync/atomic" "time" "k8s.io/kubernetes/pkg/admission" @@ -240,10 +237,12 @@ type Config struct { // The range of ports to be assigned to services with type=NodePort or greater ServiceNodePortRange util.PortRange - // Used for secure proxy. If empty, don't use secure proxy. - SSHUser string - SSHKeyfile string - InstallSSHKey InstallSSHKey + // Used to customize default proxy dial/tls options + ProxyDialer apiserver.ProxyDialerFunc + ProxyTLSClientConfig *tls.Config + + // Used to start and monitor tunneling + Tunneler Tunneler KubernetesServiceNodePort int } @@ -305,14 +304,11 @@ type Master struct { Handler http.Handler InsecureHandler http.Handler - // Used for secure proxy - dialer apiserver.ProxyDialerFunc - tunnels *util.SSHTunnelList - tunnelsLock sync.Mutex - installSSHKey InstallSSHKey - lastSync int64 // Seconds since Epoch - lastSyncMetric prometheus.GaugeFunc - clock util.Clock + // Used for custom proxy dialing, and proxy TLS options + proxyTransport http.RoundTripper + + // Used to start and monitor tunneling + tunneler Tunneler // storage for third party objects thirdPartyStorage storage.Interface @@ -453,7 +449,8 @@ func New(c *Config) *Master { // TODO: serviceReadWritePort should be passed in as an argument, it may not always be 443 serviceReadWritePort: 443, - installSSHKey: c.InstallSSHKey, + tunneler: c.Tunneler, + KubernetesServiceNodePort: c.KubernetesServiceNodePort, } @@ -505,10 +502,18 @@ func NewHandlerContainer(mux *http.ServeMux) *restful.Container { // init initializes master. func (m *Master) init(c *Config) { + + if c.ProxyDialer != nil || c.ProxyTLSClientConfig != nil { + m.proxyTransport = util.SetTransportDefaults(&http.Transport{ + Dial: c.ProxyDialer, + TLSClientConfig: c.ProxyTLSClientConfig, + }) + } + healthzChecks := []healthz.HealthzChecker{} - m.clock = util.RealClock{} + dbClient := func(resource string) storage.Interface { return c.StorageDestinations.get("", resource) } - podStorage := podetcd.NewStorage(dbClient("pods"), c.EnableWatchCache, c.KubeletClient) + podStorage := podetcd.NewStorage(dbClient("pods"), c.EnableWatchCache, c.KubeletClient, m.proxyTransport) podTemplateStorage := podtemplateetcd.NewREST(dbClient("podTemplates")) @@ -527,7 +532,7 @@ func (m *Master) init(c *Config) { endpointsStorage := endpointsetcd.NewREST(dbClient("endpoints"), c.EnableWatchCache) m.endpointRegistry = endpoint.NewRegistry(endpointsStorage) - nodeStorage, nodeStatusStorage := nodeetcd.NewREST(dbClient("nodes"), c.EnableWatchCache, c.KubeletClient) + nodeStorage, nodeStatusStorage := nodeetcd.NewREST(dbClient("nodes"), c.EnableWatchCache, c.KubeletClient, m.proxyTransport) m.nodeRegistry = node.NewRegistry(nodeStorage) serviceStorage := serviceetcd.NewREST(dbClient("services")) @@ -569,7 +574,7 @@ func (m *Master) init(c *Config) { "replicationControllers": controllerStorage, "replicationControllers/status": controllerStatusStorage, - "services": service.NewStorage(m.serviceRegistry, m.endpointRegistry, serviceClusterIPAllocator, serviceNodePortAllocator), + "services": service.NewStorage(m.serviceRegistry, m.endpointRegistry, serviceClusterIPAllocator, serviceNodePortAllocator, m.proxyTransport), "endpoints": endpointsStorage, "nodes": nodeStorage, "nodes/status": nodeStatusStorage, @@ -591,51 +596,13 @@ func (m *Master) init(c *Config) { "componentStatuses": componentstatus.NewStorage(func() map[string]apiserver.Server { return m.getServersToValidate(c) }), } - // establish the node proxy dialer - if len(c.SSHUser) > 0 { - // Usernames are capped @ 32 - if len(c.SSHUser) > 32 { - glog.Warning("SSH User is too long, truncating to 32 chars") - c.SSHUser = c.SSHUser[0:32] - } - glog.Infof("Setting up proxy: %s %s", c.SSHUser, c.SSHKeyfile) - - // public keyfile is written last, so check for that. - publicKeyFile := c.SSHKeyfile + ".pub" - exists, err := util.FileExists(publicKeyFile) - if err != nil { - glog.Errorf("Error detecting if key exists: %v", err) - } else if !exists { - glog.Infof("Key doesn't exist, attempting to create") - err := m.generateSSHKey(c.SSHUser, c.SSHKeyfile, publicKeyFile) - if err != nil { - glog.Errorf("Failed to create key pair: %v", err) - } - } - m.tunnels = &util.SSHTunnelList{} - m.dialer = m.Dial - m.setupSecureProxy(c.SSHUser, c.SSHKeyfile, publicKeyFile) - m.lastSync = m.clock.Now().Unix() - - // This is pretty ugly. A better solution would be to pull this all the way up into the - // server.go file. - httpKubeletClient, ok := c.KubeletClient.(*client.HTTPKubeletClient) - if ok { - httpKubeletClient.Config.Dial = m.dialer - transport, err := client.MakeTransport(httpKubeletClient.Config) - if err != nil { - glog.Errorf("Error setting up transport over SSH: %v", err) - } else { - httpKubeletClient.Client.Transport = transport - } - } else { - glog.Errorf("Failed to cast %v to HTTPKubeletClient, skipping SSH tunnel.", c.KubeletClient) - } + if m.tunneler != nil { + m.tunneler.Run(m.getNodeAddresses) healthzChecks = append(healthzChecks, healthz.NamedCheck("SSH Tunnel Check", m.IsTunnelSyncHealthy)) - m.lastSyncMetric = prometheus.NewGaugeFunc(prometheus.GaugeOpts{ + prometheus.NewGaugeFunc(prometheus.GaugeOpts{ Name: "apiserver_proxy_tunnel_sync_latency_secs", Help: "The time since the last successful synchronization of the SSH tunnels for proxy requests.", - }, func() float64 { return float64(m.secondsSinceSync()) }) + }, func() float64 { return float64(m.tunneler.SecondsSinceSync()) }) } apiVersions := []string{} @@ -875,7 +842,6 @@ func (m *Master) defaultAPIGroupVersion() *apiserver.APIGroupVersion { Admit: m.admissionControl, Context: m.requestContextMapper, - ProxyDialerFn: m.dialer, MinRequestTimeout: m.minRequestTimeout, } } @@ -1031,7 +997,6 @@ func (m *Master) thirdpartyapi(group, kind, version string) *apiserver.APIGroupV Context: m.requestContextMapper, - ProxyDialerFn: m.dialer, MinRequestTimeout: m.minRequestTimeout, } } @@ -1094,7 +1059,6 @@ func (m *Master) experimental(c *Config) *apiserver.APIGroupVersion { Admit: m.admissionControl, Context: m.requestContextMapper, - ProxyDialerFn: m.dialer, MinRequestTimeout: m.minRequestTimeout, } } @@ -1117,41 +1081,6 @@ func findExternalAddress(node *api.Node) (string, error) { return "", fmt.Errorf("Couldn't find external address: %v", node) } -func (m *Master) Dial(net, addr string) (net.Conn, error) { - // Only lock while picking a tunnel. - tunnel, err := func() (util.SSHTunnelEntry, error) { - m.tunnelsLock.Lock() - defer m.tunnelsLock.Unlock() - return m.tunnels.PickRandomTunnel() - }() - if err != nil { - return nil, err - } - - start := time.Now() - id := rand.Int63() // So you can match begins/ends in the log. - glog.V(3).Infof("[%x: %v] Dialing...", id, tunnel.Address) - defer func() { - glog.V(3).Infof("[%x: %v] Dialed in %v.", id, tunnel.Address, time.Now().Sub(start)) - }() - return tunnel.Tunnel.Dial(net, addr) -} - -func (m *Master) needToReplaceTunnels(addrs []string) bool { - m.tunnelsLock.Lock() - defer m.tunnelsLock.Unlock() - if m.tunnels == nil || m.tunnels.Len() != len(addrs) { - return true - } - // TODO (cjcullen): This doesn't need to be n^2 - for ix := range addrs { - if !m.tunnels.Has(addrs[ix]) { - return true - } - } - return false -} - func (m *Master) getNodeAddresses() ([]string, error) { nodes, err := m.nodeRegistry.ListNodes(api.NewDefaultContext(), labels.Everything(), fields.Everything()) if err != nil { @@ -1170,126 +1099,12 @@ func (m *Master) getNodeAddresses() ([]string, error) { } func (m *Master) IsTunnelSyncHealthy(req *http.Request) error { - lag := m.secondsSinceSync() + if m.tunneler == nil { + return nil + } + lag := m.tunneler.SecondsSinceSync() if lag > 600 { return fmt.Errorf("Tunnel sync is taking to long: %d", lag) } return nil } - -func (m *Master) secondsSinceSync() int64 { - now := m.clock.Now().Unix() - then := atomic.LoadInt64(&m.lastSync) - return now - then -} - -func (m *Master) replaceTunnels(user, keyfile string, newAddrs []string) error { - glog.Infof("replacing tunnels. New addrs: %v", newAddrs) - tunnels := util.MakeSSHTunnels(user, keyfile, newAddrs) - if err := tunnels.Open(); err != nil { - return err - } - m.tunnelsLock.Lock() - defer m.tunnelsLock.Unlock() - if m.tunnels != nil { - m.tunnels.Close() - } - m.tunnels = tunnels - atomic.StoreInt64(&m.lastSync, m.clock.Now().Unix()) - return nil -} - -func (m *Master) loadTunnels(user, keyfile string) error { - addrs, err := m.getNodeAddresses() - if err != nil { - return err - } - if !m.needToReplaceTunnels(addrs) { - return nil - } - // TODO: This is going to unnecessarily close connections to unchanged nodes. - // See comment about using Watch above. - glog.Info("found different nodes. Need to replace tunnels") - return m.replaceTunnels(user, keyfile, addrs) -} - -func (m *Master) refreshTunnels(user, keyfile string) error { - addrs, err := m.getNodeAddresses() - if err != nil { - return err - } - return m.replaceTunnels(user, keyfile, addrs) -} - -func (m *Master) setupSecureProxy(user, privateKeyfile, publicKeyfile string) { - // Sync loop to ensure that the SSH key has been installed. - go util.Until(func() { - if m.installSSHKey == nil { - glog.Error("Won't attempt to install ssh key: installSSHKey function is nil") - return - } - key, err := util.ParsePublicKeyFromFile(publicKeyfile) - if err != nil { - glog.Errorf("Failed to load public key: %v", err) - return - } - keyData, err := util.EncodeSSHKey(key) - if err != nil { - glog.Errorf("Failed to encode public key: %v", err) - return - } - if err := m.installSSHKey(user, keyData); err != nil { - glog.Errorf("Failed to install ssh key: %v", err) - } - }, 5*time.Minute, util.NeverStop) - // Sync loop for tunnels - // TODO: switch this to watch. - go util.Until(func() { - if err := m.loadTunnels(user, privateKeyfile); err != nil { - glog.Errorf("Failed to load SSH Tunnels: %v", err) - } - if m.tunnels != nil && m.tunnels.Len() != 0 { - // Sleep for 10 seconds if we have some tunnels. - // TODO (cjcullen): tunnels can lag behind actually existing nodes. - time.Sleep(9 * time.Second) - } - }, 1*time.Second, util.NeverStop) - // Refresh loop for tunnels - // TODO: could make this more controller-ish - go util.Until(func() { - time.Sleep(5 * time.Minute) - if err := m.refreshTunnels(user, privateKeyfile); err != nil { - glog.Errorf("Failed to refresh SSH Tunnels: %v", err) - } - }, 0*time.Second, util.NeverStop) -} - -func (m *Master) generateSSHKey(user, privateKeyfile, publicKeyfile string) error { - // TODO: user is not used. Consider removing it as an input to the function. - private, public, err := util.GenerateKey(2048) - if err != nil { - return err - } - // If private keyfile already exists, we must have only made it halfway - // through last time, so delete it. - exists, err := util.FileExists(privateKeyfile) - if err != nil { - glog.Errorf("Error detecting if private key exists: %v", err) - } else if exists { - glog.Infof("Private key exists, but public key does not") - if err := os.Remove(privateKeyfile); err != nil { - glog.Errorf("Failed to remove stale private key: %v", err) - } - } - if err := ioutil.WriteFile(privateKeyfile, util.EncodePrivateKey(private), 0600); err != nil { - return err - } - publicKeyBytes, err := util.EncodePublicKey(public) - if err != nil { - return err - } - if err := ioutil.WriteFile(publicKeyfile+".tmp", publicKeyBytes, 0600); err != nil { - return err - } - return os.Rename(publicKeyfile+".tmp", publicKeyfile) -} diff --git a/pkg/master/master_test.go b/pkg/master/master_test.go index f00f76b9de0..bdb9a79928b 100644 --- a/pkg/master/master_test.go +++ b/pkg/master/master_test.go @@ -18,6 +18,7 @@ package master import ( "bytes" + "crypto/tls" "encoding/json" "errors" "fmt" @@ -25,12 +26,9 @@ import ( "net" "net/http" "net/http/httptest" - "os" - "path/filepath" "reflect" "strings" "testing" - "time" "k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/api/latest" @@ -81,7 +79,12 @@ func setUp(t *testing.T) (Master, Config, *assert.Assertions) { // using the configuration properly. func TestNew(t *testing.T) { _, config, assert := setUp(t) + config.KubeletClient = client.FakeKubeletClient{} + + config.ProxyDialer = func(network, addr string) (net.Conn, error) { return nil, nil } + config.ProxyTLSClientConfig = &tls.Config{} + master := New(&config) // Verify many of the variables match their config counterparts @@ -106,7 +109,15 @@ func TestNew(t *testing.T) { assert.Equal(master.clusterIP, config.PublicAddress) assert.Equal(master.publicReadWritePort, config.ReadWritePort) assert.Equal(master.serviceReadWriteIP, config.ServiceReadWriteIP) - assert.Equal(master.installSSHKey, config.InstallSSHKey) + assert.Equal(master.tunneler, config.Tunneler) + + // These functions should point to the same memory location + masterDialer, _ := util.Dialer(master.proxyTransport) + masterDialerFunc := fmt.Sprintf("%p", masterDialer) + configDialerFunc := fmt.Sprintf("%p", config.ProxyDialer) + assert.Equal(masterDialerFunc, configDialerFunc) + + assert.Equal(master.proxyTransport.(*http.Transport).TLSClientConfig, config.ProxyTLSClientConfig) } // TestNewEtcdStorage verifies that the usage of NewEtcdStorage reacts properly when @@ -271,7 +282,6 @@ func TestInstallSwaggerAPI(t *testing.T) { // creates the expected APIGroupVersion based off of master. func TestDefaultAPIGroupVersion(t *testing.T) { master, _, assert := setUp(t) - master.dialer = func(network, addr string) (net.Conn, error) { return nil, nil } apiGroup := master.defaultAPIGroupVersion() @@ -279,11 +289,6 @@ func TestDefaultAPIGroupVersion(t *testing.T) { assert.Equal(apiGroup.Admit, master.admissionControl) assert.Equal(apiGroup.Context, master.requestContextMapper) assert.Equal(apiGroup.MinRequestTimeout, master.minRequestTimeout) - - // These functions should be different instances of the same function - groupDialerFunc := fmt.Sprintf("%+v", apiGroup.ProxyDialerFn) - masterDialerFunc := fmt.Sprintf("%+v", master.dialer) - assert.Equal(groupDialerFunc, masterDialerFunc) } // TestExpapi verifies that the unexported exapi creates @@ -299,42 +304,6 @@ func TestExpapi(t *testing.T) { assert.Equal(expAPIGroup.Version, latest.GroupOrDie("extensions").GroupVersion) } -// TestSecondsSinceSync verifies that proper results are returned -// when checking the time between syncs -func TestSecondsSinceSync(t *testing.T) { - master, _, assert := setUp(t) - master.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix() - - // Nano Second. No difference. - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 2, time.UTC)} - assert.Equal(int64(0), master.secondsSinceSync()) - - // Second - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 2, 1, time.UTC)} - assert.Equal(int64(1), master.secondsSinceSync()) - - // Minute - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 2, 1, 1, time.UTC)} - assert.Equal(int64(60), master.secondsSinceSync()) - - // Hour - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 2, 1, 1, 1, time.UTC)} - assert.Equal(int64(3600), master.secondsSinceSync()) - - // Day - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 2, 1, 1, 1, 1, time.UTC)} - assert.Equal(int64(86400), master.secondsSinceSync()) - - // Month - master.clock = &util.FakeClock{Time: time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC)} - assert.Equal(int64(2678400), master.secondsSinceSync()) - - // Future Month. Should be -Month. - master.lastSync = time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC).Unix() - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC)} - assert.Equal(int64(-2678400), master.secondsSinceSync()) -} - // TestGetNodeAddresses verifies that proper results are returned // when requesting node addresses. func TestGetNodeAddresses(t *testing.T) { @@ -366,73 +335,6 @@ func TestGetNodeAddresses(t *testing.T) { assert.Equal([]string{"127.0.0.2", "127.0.0.2"}, addrs) } -// TestRefreshTunnels verifies that the function errors when no addresses -// are associated with nodes -func TestRefreshTunnels(t *testing.T) { - master, _, assert := setUp(t) - - // Fail case (no addresses associated with nodes) - assert.Error(master.refreshTunnels("test", "/tmp/undefined")) - - // TODO: pass case without needing actual connections? -} - -// TestIsTunnelSyncHealthy verifies that the 600 second lag test -// is honored. -func TestIsTunnelSyncHealthy(t *testing.T) { - master, _, assert := setUp(t) - - // Pass case: 540 second lag - master.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix() - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 9, 1, 1, time.UTC)} - err := master.IsTunnelSyncHealthy(nil) - assert.NoError(err, "IsTunnelSyncHealthy() should not have returned an error.") - - // Fail case: 720 second lag - master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 12, 1, 1, time.UTC)} - err = master.IsTunnelSyncHealthy(nil) - assert.Error(err, "IsTunnelSyncHealthy() should have returned an error.") -} - -// generateTempFile creates a temporary file path -func generateTempFilePath(prefix string) string { - tmpPath, _ := filepath.Abs(fmt.Sprintf("%s/%s-%d", os.TempDir(), prefix, time.Now().Unix())) - return tmpPath -} - -// TestGenerateSSHKey verifies that SSH key generation does indeed -// generate keys even with keys already exist. -func TestGenerateSSHKey(t *testing.T) { - master, _, assert := setUp(t) - - privateKey := generateTempFilePath("private") - publicKey := generateTempFilePath("public") - - // Make sure we have no test keys laying around - os.Remove(privateKey) - os.Remove(publicKey) - - // Pass case: Sunny day case - err := master.generateSSHKey("unused", privateKey, publicKey) - assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) - - // Pass case: PrivateKey exists test case - os.Remove(publicKey) - err = master.generateSSHKey("unused", privateKey, publicKey) - assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) - - // Pass case: PublicKey exists test case - os.Remove(privateKey) - err = master.generateSSHKey("unused", privateKey, publicKey) - assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) - - // Make sure we have no test keys laying around - os.Remove(privateKey) - os.Remove(publicKey) - - // TODO: testing error cases where the file can not be removed? -} - func TestDiscoveryAtAPIS(t *testing.T) { master, config, assert := setUp(t) master.exp = true diff --git a/pkg/master/tunneler.go b/pkg/master/tunneler.go new file mode 100644 index 00000000000..f07ddde0785 --- /dev/null +++ b/pkg/master/tunneler.go @@ -0,0 +1,262 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package master + +import ( + "io/ioutil" + "math/rand" + "net" + "os" + "sync" + "sync/atomic" + "time" + + "k8s.io/kubernetes/pkg/util" + + "github.com/golang/glog" + "github.com/prometheus/client_golang/prometheus" +) + +type AddressFunc func() (addresses []string, err error) + +type Tunneler interface { + Run(AddressFunc) + Stop() + Dial(net, addr string) (net.Conn, error) + SecondsSinceSync() int64 +} + +type SSHTunneler struct { + SSHUser string + SSHKeyfile string + InstallSSHKey InstallSSHKey + + tunnels *util.SSHTunnelList + tunnelsLock sync.Mutex + lastSync int64 // Seconds since Epoch + lastSyncMetric prometheus.GaugeFunc + clock util.Clock + + getAddresses AddressFunc + stopChan chan struct{} +} + +func NewSSHTunneler(sshUser string, sshKeyfile string, installSSHKey InstallSSHKey) Tunneler { + return &SSHTunneler{ + SSHUser: sshUser, + SSHKeyfile: sshKeyfile, + InstallSSHKey: installSSHKey, + + clock: util.RealClock{}, + } +} + +// Run establishes tunnel loops and returns +func (c *SSHTunneler) Run(getAddresses AddressFunc) { + if c.stopChan != nil { + return + } + c.stopChan = make(chan struct{}) + + // Save the address getter + if getAddresses != nil { + c.getAddresses = getAddresses + } + + // Usernames are capped @ 32 + if len(c.SSHUser) > 32 { + glog.Warning("SSH User is too long, truncating to 32 chars") + c.SSHUser = c.SSHUser[0:32] + } + glog.Infof("Setting up proxy: %s %s", c.SSHUser, c.SSHKeyfile) + + // public keyfile is written last, so check for that. + publicKeyFile := c.SSHKeyfile + ".pub" + exists, err := util.FileExists(publicKeyFile) + if err != nil { + glog.Errorf("Error detecting if key exists: %v", err) + } else if !exists { + glog.Infof("Key doesn't exist, attempting to create") + err := c.generateSSHKey(c.SSHUser, c.SSHKeyfile, publicKeyFile) + if err != nil { + glog.Errorf("Failed to create key pair: %v", err) + } + } + c.tunnels = &util.SSHTunnelList{} + c.setupSecureProxy(c.SSHUser, c.SSHKeyfile, publicKeyFile) + c.lastSync = c.clock.Now().Unix() +} + +// Stop gracefully shuts down the tunneler +func (c *SSHTunneler) Stop() { + if c.stopChan != nil { + close(c.stopChan) + c.stopChan = nil + } +} + +func (c *SSHTunneler) Dial(net, addr string) (net.Conn, error) { + // Only lock while picking a tunnel. + tunnel, err := func() (util.SSHTunnelEntry, error) { + c.tunnelsLock.Lock() + defer c.tunnelsLock.Unlock() + return c.tunnels.PickRandomTunnel() + }() + if err != nil { + return nil, err + } + + start := time.Now() + id := rand.Int63() // So you can match begins/ends in the log. + glog.V(3).Infof("[%x: %v] Dialing...", id, tunnel.Address) + defer func() { + glog.V(3).Infof("[%x: %v] Dialed in %v.", id, tunnel.Address, time.Now().Sub(start)) + }() + return tunnel.Tunnel.Dial(net, addr) +} + +func (c *SSHTunneler) SecondsSinceSync() int64 { + now := c.clock.Now().Unix() + then := atomic.LoadInt64(&c.lastSync) + return now - then +} + +func (c *SSHTunneler) needToReplaceTunnels(addrs []string) bool { + c.tunnelsLock.Lock() + defer c.tunnelsLock.Unlock() + if c.tunnels == nil || c.tunnels.Len() != len(addrs) { + return true + } + // TODO (cjcullen): This doesn't need to be n^2 + for ix := range addrs { + if !c.tunnels.Has(addrs[ix]) { + return true + } + } + return false +} + +func (c *SSHTunneler) replaceTunnels(user, keyfile string, newAddrs []string) error { + glog.Infof("replacing tunnels. New addrs: %v", newAddrs) + tunnels := util.MakeSSHTunnels(user, keyfile, newAddrs) + if err := tunnels.Open(); err != nil { + return err + } + c.tunnelsLock.Lock() + defer c.tunnelsLock.Unlock() + if c.tunnels != nil { + c.tunnels.Close() + } + c.tunnels = tunnels + atomic.StoreInt64(&c.lastSync, c.clock.Now().Unix()) + return nil +} + +func (c *SSHTunneler) loadTunnels(user, keyfile string) error { + addrs, err := c.getAddresses() + if err != nil { + return err + } + if !c.needToReplaceTunnels(addrs) { + return nil + } + // TODO: This is going to unnecessarily close connections to unchanged nodes. + // See comment about using Watch above. + glog.Info("found different nodes. Need to replace tunnels") + return c.replaceTunnels(user, keyfile, addrs) +} + +func (c *SSHTunneler) refreshTunnels(user, keyfile string) error { + addrs, err := c.getAddresses() + if err != nil { + return err + } + return c.replaceTunnels(user, keyfile, addrs) +} + +func (c *SSHTunneler) setupSecureProxy(user, privateKeyfile, publicKeyfile string) { + // Sync loop to ensure that the SSH key has been installed. + go util.Until(func() { + if c.InstallSSHKey == nil { + glog.Error("Won't attempt to install ssh key: InstallSSHKey function is nil") + return + } + key, err := util.ParsePublicKeyFromFile(publicKeyfile) + if err != nil { + glog.Errorf("Failed to load public key: %v", err) + return + } + keyData, err := util.EncodeSSHKey(key) + if err != nil { + glog.Errorf("Failed to encode public key: %v", err) + return + } + if err := c.InstallSSHKey(user, keyData); err != nil { + glog.Errorf("Failed to install ssh key: %v", err) + } + }, 5*time.Minute, c.stopChan) + // Sync loop for tunnels + // TODO: switch this to watch. + go util.Until(func() { + if err := c.loadTunnels(user, privateKeyfile); err != nil { + glog.Errorf("Failed to load SSH Tunnels: %v", err) + } + if c.tunnels != nil && c.tunnels.Len() != 0 { + // Sleep for 10 seconds if we have some tunnels. + // TODO (cjcullen): tunnels can lag behind actually existing nodes. + time.Sleep(9 * time.Second) + } + }, 1*time.Second, c.stopChan) + // Refresh loop for tunnels + // TODO: could make this more controller-ish + go util.Until(func() { + time.Sleep(5 * time.Minute) + if err := c.refreshTunnels(user, privateKeyfile); err != nil { + glog.Errorf("Failed to refresh SSH Tunnels: %v", err) + } + }, 0*time.Second, c.stopChan) +} + +func (c *SSHTunneler) generateSSHKey(user, privateKeyfile, publicKeyfile string) error { + // TODO: user is not used. Consider removing it as an input to the function. + private, public, err := util.GenerateKey(2048) + if err != nil { + return err + } + // If private keyfile already exists, we must have only made it halfway + // through last time, so delete it. + exists, err := util.FileExists(privateKeyfile) + if err != nil { + glog.Errorf("Error detecting if private key exists: %v", err) + } else if exists { + glog.Infof("Private key exists, but public key does not") + if err := os.Remove(privateKeyfile); err != nil { + glog.Errorf("Failed to remove stale private key: %v", err) + } + } + if err := ioutil.WriteFile(privateKeyfile, util.EncodePrivateKey(private), 0600); err != nil { + return err + } + publicKeyBytes, err := util.EncodePublicKey(public) + if err != nil { + return err + } + if err := ioutil.WriteFile(publicKeyfile+".tmp", publicKeyBytes, 0600); err != nil { + return err + } + return os.Rename(publicKeyfile+".tmp", publicKeyfile) +} diff --git a/pkg/master/tunneler_test.go b/pkg/master/tunneler_test.go new file mode 100644 index 00000000000..7b348a525b3 --- /dev/null +++ b/pkg/master/tunneler_test.go @@ -0,0 +1,139 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package master + +import ( + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "k8s.io/kubernetes/pkg/util" + + "github.com/stretchr/testify/assert" +) + +// TestSecondsSinceSync verifies that proper results are returned +// when checking the time between syncs +func TestSecondsSinceSync(t *testing.T) { + tunneler := &SSHTunneler{} + assert := assert.New(t) + + tunneler.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix() + + // Nano Second. No difference. + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 2, time.UTC)} + assert.Equal(int64(0), tunneler.SecondsSinceSync()) + + // Second + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 2, 1, time.UTC)} + assert.Equal(int64(1), tunneler.SecondsSinceSync()) + + // Minute + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 2, 1, 1, time.UTC)} + assert.Equal(int64(60), tunneler.SecondsSinceSync()) + + // Hour + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 2, 1, 1, 1, time.UTC)} + assert.Equal(int64(3600), tunneler.SecondsSinceSync()) + + // Day + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 2, 1, 1, 1, 1, time.UTC)} + assert.Equal(int64(86400), tunneler.SecondsSinceSync()) + + // Month + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC)} + assert.Equal(int64(2678400), tunneler.SecondsSinceSync()) + + // Future Month. Should be -Month. + tunneler.lastSync = time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC).Unix() + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC)} + assert.Equal(int64(-2678400), tunneler.SecondsSinceSync()) +} + +// TestRefreshTunnels verifies that the function errors when no addresses +// are associated with nodes +func TestRefreshTunnels(t *testing.T) { + tunneler := &SSHTunneler{} + tunneler.getAddresses = func() ([]string, error) { return []string{}, nil } + assert := assert.New(t) + + // Fail case (no addresses associated with nodes) + assert.Error(tunneler.refreshTunnels("test", "/tmp/undefined")) + + // TODO: pass case without needing actual connections? +} + +// TestIsTunnelSyncHealthy verifies that the 600 second lag test +// is honored. +func TestIsTunnelSyncHealthy(t *testing.T) { + tunneler := &SSHTunneler{} + master, _, assert := setUp(t) + master.tunneler = tunneler + + // Pass case: 540 second lag + tunneler.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix() + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 9, 1, 1, time.UTC)} + err := master.IsTunnelSyncHealthy(nil) + assert.NoError(err, "IsTunnelSyncHealthy() should not have returned an error.") + + // Fail case: 720 second lag + tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 12, 1, 1, time.UTC)} + err = master.IsTunnelSyncHealthy(nil) + assert.Error(err, "IsTunnelSyncHealthy() should have returned an error.") +} + +// generateTempFile creates a temporary file path +func generateTempFilePath(prefix string) string { + tmpPath, _ := filepath.Abs(fmt.Sprintf("%s/%s-%d", os.TempDir(), prefix, time.Now().Unix())) + return tmpPath +} + +// TestGenerateSSHKey verifies that SSH key generation does indeed +// generate keys even with keys already exist. +func TestGenerateSSHKey(t *testing.T) { + tunneler := &SSHTunneler{} + assert := assert.New(t) + + privateKey := generateTempFilePath("private") + publicKey := generateTempFilePath("public") + + // Make sure we have no test keys laying around + os.Remove(privateKey) + os.Remove(publicKey) + + // Pass case: Sunny day case + err := tunneler.generateSSHKey("unused", privateKey, publicKey) + assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) + + // Pass case: PrivateKey exists test case + os.Remove(publicKey) + err = tunneler.generateSSHKey("unused", privateKey, publicKey) + assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) + + // Pass case: PublicKey exists test case + os.Remove(privateKey) + err = tunneler.generateSSHKey("unused", privateKey, publicKey) + assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err) + + // Make sure we have no test keys laying around + os.Remove(privateKey) + os.Remove(publicKey) + + // TODO: testing error cases where the file can not be removed? +} diff --git a/pkg/registry/generic/rest/proxy.go b/pkg/registry/generic/rest/proxy.go index 14bd0e0e27d..cc4d74440f7 100644 --- a/pkg/registry/generic/rest/proxy.go +++ b/pkg/registry/generic/rest/proxy.go @@ -17,10 +17,7 @@ limitations under the License. package rest import ( - "crypto/tls" - "fmt" "io" - "net" "net/http" "net/http/httputil" "net/url" @@ -29,12 +26,12 @@ import ( "time" "k8s.io/kubernetes/pkg/api/errors" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" "k8s.io/kubernetes/pkg/util/proxy" "github.com/golang/glog" "github.com/mxk/go-flowrate/flowrate" - "k8s.io/kubernetes/third_party/golang/netutil" ) // UpgradeAwareProxyHandler is a handler for proxy requests that may require an upgrade @@ -128,7 +125,7 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R return false } - backendConn, err := h.dialURL() + backendConn, err := proxy.DialURL(h.Location, h.Transport) if err != nil { h.err = err return true @@ -189,79 +186,6 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R return true } -func (h *UpgradeAwareProxyHandler) dialURL() (net.Conn, error) { - dialAddr := netutil.CanonicalAddr(h.Location) - - var dialer func(network, addr string) (net.Conn, error) - if httpTransport, ok := h.Transport.(*http.Transport); ok && httpTransport.Dial != nil { - dialer = httpTransport.Dial - } - - switch h.Location.Scheme { - case "http": - if dialer != nil { - return dialer("tcp", dialAddr) - } - return net.Dial("tcp", dialAddr) - case "https": - // TODO: this TLS logic can probably be cleaned up; it's messy in an attempt - // to preserve behavior that we don't know for sure is exercised. - - // Get the tls config from the transport if we recognize it - var tlsConfig *tls.Config - var tlsConn *tls.Conn - var err error - if h.Transport != nil { - httpTransport, ok := h.Transport.(*http.Transport) - if ok { - tlsConfig = httpTransport.TLSClientConfig - } - } - if dialer != nil { - // We have a dialer; use it to open the connection, then - // create a tls client using the connection. - netConn, err := dialer("tcp", dialAddr) - if err != nil { - return nil, err - } - // tls.Client requires non-nil config - if tlsConfig == nil { - glog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify") - tlsConfig = &tls.Config{ - InsecureSkipVerify: true, - } - } - tlsConn = tls.Client(netConn, tlsConfig) - if err := tlsConn.Handshake(); err != nil { - return nil, err - } - - } else { - // Dial - tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig) - if err != nil { - return nil, err - } - } - - // Return if we were configured to skip validation - if tlsConfig != nil && tlsConfig.InsecureSkipVerify { - return tlsConn, nil - } - - // Verify - host, _, _ := net.SplitHostPort(dialAddr) - if err := tlsConn.VerifyHostname(host); err != nil { - tlsConn.Close() - return nil, err - } - - return tlsConn, nil - default: - return nil, fmt.Errorf("unknown scheme: %s", h.Location.Scheme) - } -} - func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper { scheme := url.Scheme host := url.Host @@ -294,7 +218,12 @@ func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, er } removeCORSHeaders(resp) return resp, nil +} +var _ = util.RoundTripperWrapper(&corsRemovingTransport{}) + +func (rt *corsRemovingTransport) WrappedRoundTripper() http.RoundTripper { + return rt.RoundTripper } // removeCORSHeaders strip CORS headers sent from the backend diff --git a/pkg/registry/generic/rest/proxy_test.go b/pkg/registry/generic/rest/proxy_test.go index 1f20f3dac87..829b1a6d5ef 100644 --- a/pkg/registry/generic/rest/proxy_test.go +++ b/pkg/registry/generic/rest/proxy_test.go @@ -22,6 +22,7 @@ import ( "crypto/x509" "io" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" @@ -305,6 +306,21 @@ func TestProxyUpgrade(t *testing.T) { }, ProxyTransport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: localhostPool}}, }, + "https (valid hostname + RootCAs + custom dialer)": { + ServerFunc: func(h http.Handler) *httptest.Server { + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Errorf("https (valid hostname): proxy_test: %v", err) + } + ts := httptest.NewUnstartedServer(h) + ts.TLS = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + ts.StartTLS() + return ts + }, + ProxyTransport: &http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}, + }, } for k, tc := range testcases { diff --git a/pkg/registry/node/etcd/etcd.go b/pkg/registry/node/etcd/etcd.go index ec66b5cf46f..2b83d10f08b 100644 --- a/pkg/registry/node/etcd/etcd.go +++ b/pkg/registry/node/etcd/etcd.go @@ -31,7 +31,8 @@ import ( type REST struct { *etcdgeneric.Etcd - connection client.ConnectionInfoGetter + connection client.ConnectionInfoGetter + proxyTransport http.RoundTripper } // StatusREST implements the REST endpoint for changing the status of a pod. @@ -49,7 +50,7 @@ func (r *StatusREST) Update(ctx api.Context, obj runtime.Object) (runtime.Object } // NewREST returns a RESTStorage object that will work against nodes. -func NewREST(s storage.Interface, useCacher bool, connection client.ConnectionInfoGetter) (*REST, *StatusREST) { +func NewREST(s storage.Interface, useCacher bool, connection client.ConnectionInfoGetter, proxyTransport http.RoundTripper) (*REST, *StatusREST) { prefix := "/minions" storageInterface := s @@ -91,7 +92,7 @@ func NewREST(s storage.Interface, useCacher bool, connection client.ConnectionIn statusStore := *store statusStore.UpdateStrategy = node.StatusStrategy - return &REST{store, connection}, &StatusREST{store: &statusStore} + return &REST{store, connection, proxyTransport}, &StatusREST{store: &statusStore} } // Implement Redirector. @@ -99,5 +100,5 @@ var _ = rest.Redirector(&REST{}) // ResourceLocation returns a URL to which one can send traffic for the specified node. func (r *REST) ResourceLocation(ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { - return node.ResourceLocation(r, r.connection, ctx, id) + return node.ResourceLocation(r, r.connection, r.proxyTransport, ctx, id) } diff --git a/pkg/registry/node/etcd/etcd_test.go b/pkg/registry/node/etcd/etcd_test.go index af23cb0c1da..120735e1f09 100644 --- a/pkg/registry/node/etcd/etcd_test.go +++ b/pkg/registry/node/etcd/etcd_test.go @@ -38,7 +38,7 @@ func (fakeConnectionInfoGetter) GetConnectionInfo(host string) (string, uint, ht func newStorage(t *testing.T) (*REST, *tools.FakeEtcdClient) { etcdStorage, fakeClient := registrytest.NewEtcdStorage(t, "") - storage, _ := NewREST(etcdStorage, false, fakeConnectionInfoGetter{}) + storage, _ := NewREST(etcdStorage, false, fakeConnectionInfoGetter{}, nil) return storage, fakeClient } diff --git a/pkg/registry/node/strategy.go b/pkg/registry/node/strategy.go index 76f77e0239d..679a53b8fa9 100644 --- a/pkg/registry/node/strategy.go +++ b/pkg/registry/node/strategy.go @@ -136,7 +136,7 @@ func MatchNode(label labels.Selector, field fields.Selector) generic.Matcher { } // ResourceLocation returns an URL and transport which one can use to send traffic for the specified node. -func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGetter, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { +func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGetter, proxyTransport http.RoundTripper, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { schemeReq, name, portReq, valid := util.SplitSchemeNamePort(id) if !valid { return nil, nil, errors.NewBadRequest(fmt.Sprintf("invalid node request %q", id)) @@ -155,7 +155,7 @@ func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGet if portReq == "" || strconv.Itoa(ports.KubeletPort) == portReq { // Ignore requested scheme, use scheme provided by GetConnectionInfo - scheme, port, transport, err := connection.GetConnectionInfo(host) + scheme, port, kubeletTransport, err := connection.GetConnectionInfo(host) if err != nil { return nil, nil, err } @@ -166,8 +166,8 @@ func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGet strconv.FormatUint(uint64(port), 10), ), }, - transport, + kubeletTransport, nil } - return &url.URL{Scheme: schemeReq, Host: net.JoinHostPort(host, portReq)}, nil, nil + return &url.URL{Scheme: schemeReq, Host: net.JoinHostPort(host, portReq)}, proxyTransport, nil } diff --git a/pkg/registry/pod/etcd/etcd.go b/pkg/registry/pod/etcd/etcd.go index 5b71cd782e0..9faa40abf62 100644 --- a/pkg/registry/pod/etcd/etcd.go +++ b/pkg/registry/pod/etcd/etcd.go @@ -56,10 +56,11 @@ type PodStorage struct { // REST implements a RESTStorage for pods against etcd type REST struct { *etcdgeneric.Etcd + proxyTransport http.RoundTripper } // NewStorage returns a RESTStorage object that will work against pods. -func NewStorage(s storage.Interface, useCacher bool, k client.ConnectionInfoGetter) PodStorage { +func NewStorage(s storage.Interface, useCacher bool, k client.ConnectionInfoGetter, proxyTransport http.RoundTripper) PodStorage { prefix := "/pods" storageInterface := s @@ -106,11 +107,11 @@ func NewStorage(s storage.Interface, useCacher bool, k client.ConnectionInfoGett statusStore.UpdateStrategy = pod.StatusStrategy return PodStorage{ - Pod: &REST{store}, + Pod: &REST{store, proxyTransport}, Binding: &BindingREST{store: store}, Status: &StatusREST{store: &statusStore}, Log: &LogREST{store: store, kubeletConn: k}, - Proxy: &ProxyREST{store: store}, + Proxy: &ProxyREST{store: store, proxyTransport: proxyTransport}, Exec: &ExecREST{store: store, kubeletConn: k}, Attach: &AttachREST{store: store, kubeletConn: k}, PortForward: &PortForwardREST{store: store, kubeletConn: k}, @@ -122,7 +123,7 @@ var _ = rest.Redirector(&REST{}) // ResourceLocation returns a pods location from its HostIP func (r *REST) ResourceLocation(ctx api.Context, name string) (*url.URL, http.RoundTripper, error) { - return pod.ResourceLocation(r, ctx, name) + return pod.ResourceLocation(r, r.proxyTransport, ctx, name) } // BindingREST implements the REST endpoint for binding pods to nodes when etcd is in use. @@ -256,7 +257,8 @@ func (r *LogREST) NewGetOptions() (runtime.Object, bool, string) { // ProxyREST implements the proxy subresource for a Pod // TODO: move me into pod/rest - I'm generic to store type via ResourceGetter type ProxyREST struct { - store *etcdgeneric.Etcd + store *etcdgeneric.Etcd + proxyTransport http.RoundTripper } // Implement Connecter @@ -285,7 +287,7 @@ func (r *ProxyREST) Connect(ctx api.Context, id string, opts runtime.Object) (re if !ok { return nil, fmt.Errorf("Invalid options object: %#v", opts) } - location, transport, err := pod.ResourceLocation(r.store, ctx, id) + location, transport, err := pod.ResourceLocation(r.store, r.proxyTransport, ctx, id) if err != nil { return nil, err } diff --git a/pkg/registry/pod/etcd/etcd_test.go b/pkg/registry/pod/etcd/etcd_test.go index 0c4fd6f6e36..19cc073954b 100644 --- a/pkg/registry/pod/etcd/etcd_test.go +++ b/pkg/registry/pod/etcd/etcd_test.go @@ -38,7 +38,7 @@ import ( func newStorage(t *testing.T) (*REST, *BindingREST, *StatusREST, *tools.FakeEtcdClient) { etcdStorage, fakeClient := registrytest.NewEtcdStorage(t, "") - storage := NewStorage(etcdStorage, false, nil) + storage := NewStorage(etcdStorage, false, nil, nil) return storage.Pod, storage.Binding, storage.Status, fakeClient } @@ -740,7 +740,7 @@ func TestEtcdUpdateStatus(t *testing.T) { func TestPodLogValidates(t *testing.T) { etcdStorage, _ := registrytest.NewEtcdStorage(t, "") - storage := NewStorage(etcdStorage, false, nil) + storage := NewStorage(etcdStorage, false, nil, nil) negativeOne := int64(-1) testCases := []*api.PodLogOptions{ diff --git a/pkg/registry/pod/strategy.go b/pkg/registry/pod/strategy.go index 4029715b977..82ac5c49b0f 100644 --- a/pkg/registry/pod/strategy.go +++ b/pkg/registry/pod/strategy.go @@ -17,7 +17,6 @@ limitations under the License. package pod import ( - "crypto/tls" "fmt" "net" "net/http" @@ -47,13 +46,6 @@ type podStrategy struct { // objects via the REST API. var Strategy = podStrategy{api.Scheme, api.SimpleNameGenerator} -// PodProxyTransport is used by the API proxy to connect to pods -// Exported to allow overriding TLS options (like adding a client certificate) -var PodProxyTransport = util.SetTransportDefaults(&http.Transport{ - // Turn off hostname verification, because connections are to assigned IPs, not deterministic - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, -}) - // NamespaceScoped is true for pods. func (podStrategy) NamespaceScoped() bool { return true @@ -195,7 +187,7 @@ func getPod(getter ResourceGetter, ctx api.Context, name string) (*api.Pod, erro } // ResourceLocation returns a URL to which one can send traffic for the specified pod. -func ResourceLocation(getter ResourceGetter, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { +func ResourceLocation(getter ResourceGetter, rt http.RoundTripper, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) { // Allow ID as "podname" or "podname:port" or "scheme:podname:port". // If port is not specified, try to use the first defined port on the pod. scheme, name, port, valid := util.SplitSchemeNamePort(id) @@ -227,7 +219,7 @@ func ResourceLocation(getter ResourceGetter, ctx api.Context, id string) (*url.U } else { loc.Host = net.JoinHostPort(pod.Status.PodIP, port) } - return loc, PodProxyTransport, nil + return loc, rt, nil } // LogLocation returns the log URL for a pod container. If opts.Container is blank diff --git a/pkg/registry/service/rest.go b/pkg/registry/service/rest.go index 3b17d72babf..3b524a2bc7f 100644 --- a/pkg/registry/service/rest.go +++ b/pkg/registry/service/rest.go @@ -17,7 +17,6 @@ limitations under the License. package service import ( - "crypto/tls" "fmt" "math/rand" "net" @@ -48,23 +47,18 @@ type REST struct { endpoints endpoint.Registry serviceIPs ipallocator.Interface serviceNodePorts portallocator.Interface + proxyTransport http.RoundTripper } -// ServiceProxyTransport is used by the API proxy to connect to services -// Exported to allow overriding TLS options (like adding a client certificate) -var ServiceProxyTransport = util.SetTransportDefaults(&http.Transport{ - // Turn off hostname verification, because connections are to assigned IPs, not deterministic - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, -}) - // NewStorage returns a new REST. func NewStorage(registry Registry, endpoints endpoint.Registry, serviceIPs ipallocator.Interface, - serviceNodePorts portallocator.Interface) *REST { + serviceNodePorts portallocator.Interface, proxyTransport http.RoundTripper) *REST { return &REST{ registry: registry, endpoints: endpoints, serviceIPs: serviceIPs, serviceNodePorts: serviceNodePorts, + proxyTransport: proxyTransport, } } @@ -314,7 +308,7 @@ func (rs *REST) ResourceLocation(ctx api.Context, id string) (*url.URL, http.Rou return &url.URL{ Scheme: svcScheme, Host: net.JoinHostPort(ip, strconv.Itoa(port)), - }, ServiceProxyTransport, nil + }, rs.proxyTransport, nil } } } diff --git a/pkg/registry/service/rest_test.go b/pkg/registry/service/rest_test.go index 288b1d1257c..3fe14916ee7 100644 --- a/pkg/registry/service/rest_test.go +++ b/pkg/registry/service/rest_test.go @@ -46,7 +46,7 @@ func NewTestREST(t *testing.T, endpoints *api.EndpointsList) (*REST, *registryte portRange := util.PortRange{Base: 30000, Size: 1000} portAllocator := portallocator.NewPortAllocator(portRange) - storage := NewStorage(registry, endpointRegistry, r, portAllocator) + storage := NewStorage(registry, endpointRegistry, r, portAllocator, nil) return storage, registry } diff --git a/pkg/util/http.go b/pkg/util/http.go index e8253fa80f5..8f35ce49bbf 100644 --- a/pkg/util/http.go +++ b/pkg/util/http.go @@ -17,7 +17,10 @@ limitations under the License. package util import ( + "crypto/tls" + "fmt" "io" + "net" "net/http" "net/url" "strings" @@ -62,3 +65,40 @@ func SetTransportDefaults(t *http.Transport) *http.Transport { } return t } + +type RoundTripperWrapper interface { + http.RoundTripper + WrappedRoundTripper() http.RoundTripper +} + +type DialFunc func(net, addr string) (net.Conn, error) + +func Dialer(transport http.RoundTripper) (DialFunc, error) { + if transport == nil { + return nil, nil + } + + switch transport := transport.(type) { + case *http.Transport: + return transport.Dial, nil + case RoundTripperWrapper: + return Dialer(transport.WrappedRoundTripper()) + default: + return nil, fmt.Errorf("unknown transport type: %v", transport) + } +} + +func TLSClientConfig(transport http.RoundTripper) (*tls.Config, error) { + if transport == nil { + return nil, nil + } + + switch transport := transport.(type) { + case *http.Transport: + return transport.TLSClientConfig, nil + case RoundTripperWrapper: + return TLSClientConfig(transport.WrappedRoundTripper()) + default: + return nil, fmt.Errorf("unknown transport type: %v", transport) + } +} diff --git a/pkg/util/proxy/dial.go b/pkg/util/proxy/dial.go new file mode 100644 index 00000000000..07982b79386 --- /dev/null +++ b/pkg/util/proxy/dial.go @@ -0,0 +1,106 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "net/url" + + "github.com/golang/glog" + + "k8s.io/kubernetes/pkg/util" + "k8s.io/kubernetes/third_party/golang/netutil" +) + +func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) { + dialAddr := netutil.CanonicalAddr(url) + + dialer, _ := util.Dialer(transport) + + switch url.Scheme { + case "http": + if dialer != nil { + return dialer("tcp", dialAddr) + } + return net.Dial("tcp", dialAddr) + case "https": + // Get the tls config from the transport if we recognize it + var tlsConfig *tls.Config + var tlsConn *tls.Conn + var err error + tlsConfig, _ = util.TLSClientConfig(transport) + + if dialer != nil { + // We have a dialer; use it to open the connection, then + // create a tls client using the connection. + netConn, err := dialer("tcp", dialAddr) + if err != nil { + return nil, err + } + if tlsConfig == nil { + // tls.Client requires non-nil config + glog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify") + // tls.Handshake() requires ServerName or InsecureSkipVerify + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { + // tls.Handshake() requires ServerName or InsecureSkipVerify + // infer the ServerName from the hostname we're connecting to. + inferredHost := dialAddr + if host, _, err := net.SplitHostPort(dialAddr); err == nil { + inferredHost = host + } + // Make a copy to avoid polluting the provided config + tlsConfigCopy := *tlsConfig + tlsConfigCopy.ServerName = inferredHost + tlsConfig = &tlsConfigCopy + } + tlsConn = tls.Client(netConn, tlsConfig) + if err := tlsConn.Handshake(); err != nil { + netConn.Close() + return nil, err + } + + } else { + // Dial + tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig) + if err != nil { + return nil, err + } + } + + // Return if we were configured to skip validation + if tlsConfig != nil && tlsConfig.InsecureSkipVerify { + return tlsConn, nil + } + + // Verify + host, _, _ := net.SplitHostPort(dialAddr) + if err := tlsConn.VerifyHostname(host); err != nil { + tlsConn.Close() + return nil, err + } + + return tlsConn, nil + default: + return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme) + } +} diff --git a/pkg/util/proxy/transport.go b/pkg/util/proxy/transport.go index 80855f46b8b..b3365784524 100644 --- a/pkg/util/proxy/transport.go +++ b/pkg/util/proxy/transport.go @@ -31,6 +31,7 @@ import ( "golang.org/x/net/html" "golang.org/x/net/html/atom" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/sets" ) @@ -118,6 +119,12 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { return t.rewriteResponse(req, resp) } +var _ = util.RoundTripperWrapper(&Transport{}) + +func (rt *Transport) WrappedRoundTripper() http.RoundTripper { + return rt.RoundTripper +} + // rewriteURL rewrites a single URL to go through the proxy, if the URL refers // to the same host as sourceURL, which is the page on which the target URL // occurred. If any error occurs (e.g. parsing), it returns targetURL. diff --git a/test/e2e/kubectl.go b/test/e2e/kubectl.go index 88729a403aa..9ee2916b86d 100644 --- a/test/e2e/kubectl.go +++ b/test/e2e/kubectl.go @@ -669,6 +669,9 @@ var _ = Describe("Kubectl client", func() { By("curling proxy /api/ output") localAddr := fmt.Sprintf("http://localhost:%d/api/", port) apiVersions, err := getAPIVersions(localAddr) + if err != nil { + Failf("Expected at least one supported apiversion, got error %v", err) + } if len(apiVersions.Versions) < 1 { Failf("Expected at least one supported apiversion, got %v", apiVersions) }