From 30a89968a49c01f84bf6b61167c1a5401989731f Mon Sep 17 00:00:00 2001 From: Brendan Burns Date: Wed, 27 May 2015 21:38:21 -0700 Subject: [PATCH] Initial proxy tunnelling. --- cmd/kube-apiserver/app/server.go | 6 ++ pkg/apiserver/api_installer.go | 5 +- pkg/apiserver/apiserver.go | 5 +- pkg/apiserver/proxy.go | 8 +- pkg/master/master.go | 102 +++++++++++++++++++- pkg/util/ssh.go | 155 +++++++++++++++++++++++++------ 6 files changed, 244 insertions(+), 37 deletions(-) diff --git a/cmd/kube-apiserver/app/server.go b/cmd/kube-apiserver/app/server.go index 4a6e7322830..dc4b4f6859b 100644 --- a/cmd/kube-apiserver/app/server.go +++ b/cmd/kube-apiserver/app/server.go @@ -97,6 +97,8 @@ type APIServer struct { MaxRequestsInFlight int MinRequestTimeout int LongRunningRequestRE string + SSHUser string + SSHKeyfile string } // NewAPIServer creates a new APIServer object with default parameters @@ -201,6 +203,8 @@ func (s *APIServer) AddFlags(fs *pflag.FlagSet) { fs.IntVar(&s.MaxRequestsInFlight, "max-requests-inflight", 400, "The maximum number of requests in flight at a given time. When the server exceeds this, it rejects requests. Zero for no limit.") fs.IntVar(&s.MinRequestTimeout, "min-request-timeout", 1800, "An optional field indicating the minimum number of seconds a handler must keep a request open before timing it out. Currently only honored by the watch request handler, which picks a randomized value above this number as the connection timeout, to spread out load.") fs.StringVar(&s.LongRunningRequestRE, "long-running-request-regexp", "[.*\\/watch$][^\\/proxy.*]", "A regular expression matching long running requests which should be excluded from maximum inflight request handling.") + fs.StringVar(&s.SSHUser, "ssh-user", "", "If non-empty, use secure SSH proxy to the nodes, using this user name") + fs.StringVar(&s.SSHKeyfile, "ssh-keyfile", "", "If non-empty, use secure SSH proxy to the nodes, using this user keyfile") } // TODO: Longer term we should read this from some config store, rather than a flag. @@ -378,6 +382,8 @@ func (s *APIServer) Run(_ []string) error { ClusterName: s.ClusterName, ExternalHost: s.ExternalHost, MinRequestTimeout: s.MinRequestTimeout, + SSHUser: s.SSHUser, + SSHKeyfile: s.SSHKeyfile, } m := master.New(config) diff --git a/pkg/apiserver/api_installer.go b/pkg/apiserver/api_installer.go index 2fdcaba9073..7c0cc173df5 100644 --- a/pkg/apiserver/api_installer.go +++ b/pkg/apiserver/api_installer.go @@ -18,6 +18,7 @@ package apiserver import ( "fmt" + "net" "net/http" "net/url" gpath "path" @@ -55,14 +56,14 @@ type action struct { var errEmptyName = errors.NewBadRequest("name must be provided") // Installs handlers for API resources. -func (a *APIInstaller) Install() (ws *restful.WebService, errors []error) { +func (a *APIInstaller) Install(proxyDialer func(network, addr string) (net.Conn, error)) (ws *restful.WebService, errors []error) { errors = make([]error, 0) // Create the WebService. ws = a.newWebService() redirectHandler := (&RedirectHandler{a.group.Storage, a.group.Codec, a.group.Context, a.info}) - proxyHandler := (&ProxyHandler{a.prefix + "/proxy/", a.group.Storage, a.group.Codec, a.group.Context, a.info}) + proxyHandler := (&ProxyHandler{a.prefix + "/proxy/", a.group.Storage, a.group.Codec, a.group.Context, a.info, proxyDialer}) // Register the paths in a deterministic (sorted) order to get a deterministic swagger spec. paths := make([]string, len(a.group.Storage)) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index addf2d447c5..3d7b440b3c3 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -22,6 +22,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "net/http" "path" "strconv" @@ -149,7 +150,7 @@ type RestContainer struct { // InstallREST registers the REST handlers (storage, watch, proxy and redirect) into a restful Container. // It is expected that the provided path root prefix will serve all operations. Root MUST NOT end // in a slash. A restful WebService is created for the group and version. -func (g *APIGroupVersion) InstallREST(container *RestContainer) error { +func (g *APIGroupVersion) InstallREST(container *RestContainer, proxyDialer func(network, addr string) (net.Conn, error)) error { info := &APIRequestInfoResolver{util.NewStringSet(strings.TrimPrefix(g.Root, "/")), g.Mapper} prefix := path.Join(g.Root, g.Version) @@ -159,7 +160,7 @@ func (g *APIGroupVersion) InstallREST(container *RestContainer) error { prefix: prefix, minRequestTimeout: container.MinRequestTimeout, } - ws, registrationErrors := installer.Install() + ws, registrationErrors := installer.Install(proxyDialer) container.Add(ws) return errors.NewAggregate(registrationErrors) } diff --git a/pkg/apiserver/proxy.go b/pkg/apiserver/proxy.go index 8609da4e442..3f5bd826ac0 100644 --- a/pkg/apiserver/proxy.go +++ b/pkg/apiserver/proxy.go @@ -49,6 +49,8 @@ type ProxyHandler struct { codec runtime.Codec context api.RequestContextMapper apiRequestInfoResolver *APIRequestInfoResolver + + dial func(network, addr string) (net.Conn, error) } func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -119,9 +121,9 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { httpCode = http.StatusNotFound return } - // TODO: make this dynamic - location.Host = "localhost" - location.Scheme = "http" + if r.dial != nil { + transport = &http.Transport{Dial: r.dial} + } // Default to http if location.Scheme == "" { diff --git a/pkg/master/master.go b/pkg/master/master.go index 868832ba722..94835933052 100644 --- a/pkg/master/master.go +++ b/pkg/master/master.go @@ -39,6 +39,8 @@ import ( "github.com/GoogleCloudPlatform/kubernetes/pkg/auth/authorizer" "github.com/GoogleCloudPlatform/kubernetes/pkg/auth/handlers" "github.com/GoogleCloudPlatform/kubernetes/pkg/client" + "github.com/GoogleCloudPlatform/kubernetes/pkg/fields" + "github.com/GoogleCloudPlatform/kubernetes/pkg/labels" "github.com/GoogleCloudPlatform/kubernetes/pkg/master/ports" "github.com/GoogleCloudPlatform/kubernetes/pkg/registry/componentstatus" controlleretcd "github.com/GoogleCloudPlatform/kubernetes/pkg/registry/controller/etcd" @@ -143,6 +145,10 @@ type Config struct { // The range of ports to be assigned to services with type=NodePort or greater ServiceNodePortRange util.PortRange + + // Used for secure proxy. If empty, don't use secure proxy. + SSHUser string + SSHKeyfile string } // Master contains state for a Kubernetes cluster master/api server. @@ -196,6 +202,9 @@ type Master struct { // "Outputs" Handler http.Handler InsecureHandler http.Handler + + // Used for secure proxy + tunnels util.SSHTunnelList } // NewEtcdHelper returns an EtcdHelper for the provided arguments or an error if the version @@ -474,15 +483,22 @@ func (m *Master) init(c *Config) { "componentStatuses": componentstatus.NewStorage(func() map[string]apiserver.Server { return m.getServersToValidate(c) }), } + var proxyDialer func(net, addr string) (net.Conn, error) + if len(c.SSHUser) > 0 { + glog.Infof("Setting up proxy: %s %s", c.SSHUser, c.SSHKeyfile) + m.setupSecureProxy(c.SSHUser, c.SSHKeyfile) + proxyDialer = m.Dial + } + apiVersions := []string{} if m.v1beta3 { - if err := m.api_v1beta3().InstallREST(m.handlerContainer); err != nil { + if err := m.api_v1beta3().InstallREST(m.handlerContainer, proxyDialer); err != nil { glog.Fatalf("Unable to setup API v1beta3: %v", err) } apiVersions = append(apiVersions, "v1beta3") } if m.v1 { - if err := m.api_v1().InstallREST(m.handlerContainer); err != nil { + if err := m.api_v1().InstallREST(m.handlerContainer, proxyDialer); err != nil { glog.Fatalf("Unable to setup API v1: %v", err) } apiVersions = append(apiVersions, "v1") @@ -703,3 +719,85 @@ func (m *Master) api_v1() *apiserver.APIGroupVersion { version.Codec = v1.Codec return version } + +func findExternalAddress(node *api.Node) (string, error) { + for ix := range node.Status.Addresses { + addr := &node.Status.Addresses[ix] + if addr.Type == api.NodeExternalIP { + return addr.Address, nil + } + } + return "", fmt.Errorf("Couldn't find external address: %v", node) +} + +func (m *Master) Dial(net, addr string) (net.Conn, error) { + return m.tunnels.Dial(net, addr) +} + +func (m *Master) detectTunnelChanges(addrs []string) bool { + if len(m.tunnels) != len(addrs) { + return true + } + for ix := range addrs { + if !m.tunnels.Has(addrs[ix]) { + return true + } + } + return false +} + +func (m *Master) loadTunnels(user, keyfile string) error { + nodes, err := m.nodeRegistry.ListMinions(api.NewDefaultContext(), labels.Everything(), fields.Everything()) + if err != nil { + return err + } + result := []string{} + for ix := range nodes.Items { + node := &nodes.Items[ix] + addr, err := findExternalAddress(node) + if err != nil { + return err + } + result = append(result, addr) + } + changesExist := m.detectTunnelChanges(result) + if !changesExist { + return nil + } + + // TODO: This is going to drop connections in the middle. See comment about using Watch above. + tunnels, err := util.MakeSSHTunnels(user, keyfile, result) + if err != nil { + return err + } + tunnels.Open() + if m.tunnels != nil { + m.tunnels.Close() + } + m.tunnels = tunnels + return nil +} + +func (m *Master) setupSecureProxy(user, keyfile string) { + loadTunnelsPrintError := func() { + if err := m.loadTunnels(user, keyfile); err != nil { + glog.Errorf("Failed to load SSH Tunnels: %v", err) + } + } + + // Sync loop for tunnels + // TODO: switch this to watch. + go func() { + for { + loadTunnelsPrintError() + + var sleep time.Duration + if len(m.tunnels) == 0 { + sleep = time.Second + } else { + sleep = time.Second * 120 + } + time.Sleep(sleep) + } + }() +} diff --git a/pkg/util/ssh.go b/pkg/util/ssh.go index 7760761b339..09b82c2d200 100644 --- a/pkg/util/ssh.go +++ b/pkg/util/ssh.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "io/ioutil" + "math/rand" "net" "os" @@ -31,15 +32,12 @@ import ( // TODO: Unit tests for this code, we can spin up a test SSH server with instructions here: // https://godoc.org/golang.org/x/crypto/ssh#ServerConn type SSHTunnel struct { - Config *ssh.ClientConfig - Host string - SSHPort int - LocalPort int - RemoteHost string - RemotePort int - running bool - sock net.Listener - client *ssh.Client + Config *ssh.ClientConfig + Host string + SSHPort string + running bool + sock net.Listener + client *ssh.Client } func (s *SSHTunnel) copyBytes(out io.Writer, in io.Reader) { @@ -48,7 +46,7 @@ func (s *SSHTunnel) copyBytes(out io.Writer, in io.Reader) { } } -func NewSSHTunnel(user, keyfile, host, remoteHost string, localPort, remotePort int) (*SSHTunnel, error) { +func NewSSHTunnel(user, keyfile, host string) (*SSHTunnel, error) { signer, err := MakePrivateKeySigner(keyfile) if err != nil { return nil, err @@ -58,44 +56,51 @@ func NewSSHTunnel(user, keyfile, host, remoteHost string, localPort, remotePort Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, } return &SSHTunnel{ - Config: &config, - Host: host, - SSHPort: 22, - LocalPort: localPort, - RemotePort: remotePort, - RemoteHost: remoteHost, + Config: &config, + Host: host, + SSHPort: "22", }, nil } func (s *SSHTunnel) Open() error { var err error - s.client, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", s.Host, s.SSHPort), s.Config) + s.client, err = ssh.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config) if err != nil { return err } - s.sock, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", s.LocalPort)) + return nil +} + +func (s *SSHTunnel) Dial(network, address string) (net.Conn, error) { + return s.client.Dial(network, address) +} + +func (s *SSHTunnel) Listen(remoteHost, localPort, remotePort string) error { + var err error + s.sock, err = net.Listen("tcp", net.JoinHostPort("localhost", localPort)) if err != nil { return err } s.running = true - return nil -} - -func (s *SSHTunnel) Listen() { for s.running { conn, err := s.sock.Accept() if err != nil { - glog.Errorf("Error listening for ssh tunnel to %s (%v)", s.RemoteHost, err) + if s.running { + glog.Errorf("Error listening for ssh tunnel to %s (%v)", remoteHost, err) + } else { + glog.V(4).Infof("Error listening for ssh tunnel to %s (%v), this is likely due to the tunnel shutting down.", remoteHost, err) + } continue } - if err := s.tunnel(conn); err != nil { + if err := s.tunnel(conn, remoteHost, remotePort); err != nil { glog.Errorf("Error starting tunnel: %v", err) } } + return nil } -func (s *SSHTunnel) tunnel(conn net.Conn) error { - tunnel, err := s.client.Dial("tcp", fmt.Sprintf("%s:%d", s.RemoteHost, s.RemotePort)) +func (s *SSHTunnel) tunnel(conn net.Conn, remoteHost, remotePort string) error { + tunnel, err := s.client.Dial("tcp", net.JoinHostPort(remoteHost, remotePort)) if err != nil { return err } @@ -104,13 +109,16 @@ func (s *SSHTunnel) tunnel(conn net.Conn) error { return nil } -func (s *SSHTunnel) Close() error { +func (s *SSHTunnel) StopListening() error { // TODO: try to shutdown copying here? s.running = false - // TODO: Aggregate errors and keep going? if err := s.sock.Close(); err != nil { return err } + return nil +} + +func (s *SSHTunnel) Close() error { if err := s.client.Close(); err != nil { return err } @@ -172,3 +180,94 @@ func MakePrivateKeySigner(key string) (ssh.Signer, error) { } return signer, nil } + +/* +if len(r.tunnels) == 0 { + list, err := listNodes() + if err != nil { + glog.Errorf("unexpected error making tunnels: %v", err) + return + } + tunnels, err := MakeNodeSSHTunnels(list) + if err != nil { + status := errToAPIStatus(err) + writeJSON(status.Code, r.codec, status, w) + httpCode = status.Code + return + } + r.tunnels = tunnels + } + // TODO: round robin here + tunnel := r.tunnels[0] + if err != nil { + status := errToAPIStatus(err) + writeJSON(status.Code, r.codec, status, w) + httpCode = status.Code + return + } + defer func() { + if err := tunnel.Close(); err != nil { + glog.Errorf("Error closing ssh tunnel: %v", err) + } + }() + if err := tunnel.Open(); err != nil { + status := errToAPIStatus(err) + writeJSON(status.Code, r.codec, status, w) + httpCode = status.Code + return + } +*/ + +type SSHTunnelEntry struct { + Address string + Tunnel *SSHTunnel +} + +type SSHTunnelList []SSHTunnelEntry + +func MakeSSHTunnels(user, keyfile string, addresses []string) (SSHTunnelList, error) { + tunnels := []SSHTunnelEntry{} + for ix := range addresses { + addr := addresses[ix] + tunnel, err := NewSSHTunnel(user, keyfile, addr) + if err != nil { + return nil, err + } + tunnels = append(tunnels, SSHTunnelEntry{addr, tunnel}) + } + return tunnels, nil +} + +func (l SSHTunnelList) Open() error { + for ix := range l { + if err := l[ix].Tunnel.Open(); err != nil { + return err + } + } + return nil +} + +func (l SSHTunnelList) Close() error { + for ix := range l { + if err := l[ix].Tunnel.Close(); err != nil { + return err + } + } + return nil +} + +func (l SSHTunnelList) Dial(network, addr string) (net.Conn, error) { + if len(l) == 0 { + return nil, fmt.Errorf("Empty tunnel list.") + } + return l[rand.Int()%len(l)].Tunnel.Dial(network, addr) +} + +func (l SSHTunnelList) Has(addr string) bool { + for ix := range l { + if l[ix].Address == addr { + return true + } + } + return false +}