diff --git a/cluster.yml b/cluster.yml index 3681dab9..17d1fb47 100644 --- a/cluster.yml +++ b/cluster.yml @@ -194,3 +194,15 @@ ingress: cloud_provider: name: aws + +# Bastion/Jump host configuration +bastion_host: + address: x.x.x.x + user: ubuntu + port: 22 + ssh_key_path: /home/user/.ssh/bastion_rsa + # or + # ssh_key: |- + # -----BEGIN RSA PRIVATE KEY----- + # + # -----END RSA PRIVATE KEY----- diff --git a/cluster/cluster.go b/cluster/cluster.go index 52b5ad16..79aad7fd 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -185,6 +185,10 @@ func ParseCluster( if err != nil { return nil, fmt.Errorf("Failed to parse cloud config file: %v", err) } + // Create k8s wrap transport for bastion host + if len(c.BastionHost.Address) > 0 { + c.K8sWrapTransport = hosts.BastionHostWrapTransport(c.BastionHost) + } return c, nil } diff --git a/cluster/defaults.go b/cluster/defaults.go index 3e61111c..c6050ef0 100644 --- a/cluster/defaults.go +++ b/cluster/defaults.go @@ -53,6 +53,17 @@ func (c *Cluster) setClusterDefaults(ctx context.Context) { if len(c.PrefixPath) == 0 { c.PrefixPath = "/" } + // Set bastion/jump host defaults + if len(c.BastionHost.Address) > 0 { + if len(c.BastionHost.Port) == 0 { + c.BastionHost.Port = DefaultSSHPort + } + if len(c.BastionHost.SSHKeyPath) == 0 { + c.BastionHost.SSHKeyPath = c.SSHKeyPath + } + c.BastionHost.SSHAgentAuth = c.SSHAgentAuth + + } for i, host := range c.Nodes { if len(host.InternalAddress) == 0 { c.Nodes[i].InternalAddress = c.Nodes[i].Address diff --git a/cluster/hosts.go b/cluster/hosts.go index 5637f938..7f3a217f 100644 --- a/cluster/hosts.go +++ b/cluster/hosts.go @@ -72,7 +72,10 @@ func (c *Cluster) InvertIndexHosts() error { newHost.ToAddLabels[k] = v } newHost.IgnoreDockerVersion = c.IgnoreDockerVersion - + if c.BastionHost.Address != "" { + // Add the bastion host information to each host object + newHost.BastionHost = c.BastionHost + } for _, role := range host.Role { logrus.Debugf("Host: " + host.Address + " has role: " + role) switch role { diff --git a/cluster/network.go b/cluster/network.go index abee1b47..822da6e4 100644 --- a/cluster/network.go +++ b/cluster/network.go @@ -220,10 +220,13 @@ func (c *Cluster) CheckClusterPorts(ctx context.Context, currentCluster *Cluster if err := c.runServicePortChecks(ctx); err != nil { return err } - if c.K8sWrapTransport == nil { + // Skip kubeapi check if we are using custom k8s dialer or bastion/jump host + if c.K8sWrapTransport == nil && len(c.BastionHost.Address) == 0 { if err := c.checkKubeAPIPort(ctx); err != nil { return err } + } else { + log.Infof(ctx, "[network] Skipping kubeapi port check") } return c.removeTCPPortListeners(ctx) diff --git a/hosts/dialer.go b/hosts/dialer.go index 7bfadf5e..0f6a85d5 100644 --- a/hosts/dialer.go +++ b/hosts/dialer.go @@ -6,6 +6,8 @@ import ( "net/http" "time" + "github.com/rancher/rke/k8s" + "github.com/rancher/types/apis/management.cattle.io/v3" "golang.org/x/crypto/ssh" ) @@ -19,22 +21,37 @@ type dialer struct { signer ssh.Signer sshKeyString string sshAddress string - sshPassphrase []byte username string netConn string dockerSocket string useSSHAgentAuth bool + bastionDialer *dialer } func newDialer(h *Host, kind string) (*dialer, error) { + // Check for Bastion host connection + var bastionDialer *dialer + if len(h.BastionHost.Address) > 0 { + bastionDialer = &dialer{ + sshAddress: fmt.Sprintf("%s:%s", h.BastionHost.Address, h.BastionHost.Port), + username: h.BastionHost.User, + sshKeyString: h.BastionHost.SSHKey, + netConn: "tcp", + useSSHAgentAuth: h.SSHAgentAuth, + } + if bastionDialer.sshKeyString == "" { + bastionDialer.sshKeyString = privateKeyPath(h.BastionHost.SSHKeyPath) + } + } + dialer := &dialer{ sshAddress: fmt.Sprintf("%s:%s", h.Address, h.Port), username: h.User, dockerSocket: h.DockerSocket, sshKeyString: h.SSHKey, netConn: "unix", - sshPassphrase: []byte(h.SavedKeyPhrase), useSSHAgentAuth: h.SSHAgentAuth, + bastionDialer: bastionDialer, } if dialer.sshKeyString == "" { @@ -72,7 +89,13 @@ func (d *dialer) DialLocalConn(network, addr string) (net.Conn, error) { } func (d *dialer) Dial(network, addr string) (net.Conn, error) { - conn, err := d.getSSHTunnelConnection() + var conn *ssh.Client + var err error + if d.bastionDialer != nil { + conn, err = d.getBastionHostTunnelConn() + } else { + conn, err = d.getSSHTunnelConnection() + } if err != nil { return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", d.sshAddress, err) } @@ -91,11 +114,10 @@ func (d *dialer) Dial(network, addr string) (net.Conn, error) { } func (d *dialer) getSSHTunnelConnection() (*ssh.Client, error) { - cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshPassphrase, d.useSSHAgentAuth) + cfg, err := getSSHConfig(d.username, d.sshKeyString, d.useSSHAgentAuth) if err != nil { return nil, fmt.Errorf("Error configuring SSH: %v", err) } - // Establish connection with SSH server return ssh.Dial("tcp", d.sshAddress, cfg) } @@ -120,3 +142,50 @@ func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error) }, }, nil } + +func (d *dialer) getBastionHostTunnelConn() (*ssh.Client, error) { + bastionCfg, err := getSSHConfig(d.bastionDialer.username, d.bastionDialer.sshKeyString, d.bastionDialer.useSSHAgentAuth) + if err != nil { + return nil, fmt.Errorf("Error configuring SSH for bastion host [%s]: %v", d.bastionDialer.sshAddress, err) + } + bastionClient, err := ssh.Dial("tcp", d.bastionDialer.sshAddress, bastionCfg) + if err != nil { + return nil, fmt.Errorf("Failed to connect to the bastion host [%s]: %v", d.bastionDialer.sshAddress, err) + } + conn, err := bastionClient.Dial(d.bastionDialer.netConn, d.sshAddress) + if err != nil { + return nil, fmt.Errorf("Failed to connect to the host [%s]: %v", d.sshAddress, err) + } + cfg, err := getSSHConfig(d.username, d.sshKeyString, d.useSSHAgentAuth) + if err != nil { + return nil, fmt.Errorf("Error configuring SSH for host [%s]: %v", d.sshAddress, err) + } + newClientConn, channels, sshRequest, err := ssh.NewClientConn(conn, d.sshAddress, cfg) + if err != nil { + return nil, fmt.Errorf("Failed to establish new ssh client conn [%s]: %v", d.sshAddress, err) + } + return ssh.NewClient(newClientConn, channels, sshRequest), nil +} + +func BastionHostWrapTransport(bastionHost v3.BastionHost) k8s.WrapTransport { + + bastionDialer := &dialer{ + sshAddress: fmt.Sprintf("%s:%s", bastionHost.Address, bastionHost.Port), + username: bastionHost.User, + sshKeyString: bastionHost.SSHKey, + netConn: "tcp", + useSSHAgentAuth: bastionHost.SSHAgentAuth, + } + + if bastionDialer.sshKeyString == "" { + bastionDialer.sshKeyString = privateKeyPath(bastionHost.SSHKeyPath) + } + return func(rt http.RoundTripper) http.RoundTripper { + if ht, ok := rt.(*http.Transport); ok { + ht.DialContext = nil + ht.DialTLS = nil + ht.Dial = bastionDialer.Dial + } + return rt + } +} diff --git a/hosts/hosts.go b/hosts/hosts.go index 309394f6..d440b4fb 100644 --- a/hosts/hosts.go +++ b/hosts/hosts.go @@ -36,6 +36,7 @@ type Host struct { DockerInfo types.Info UpdateWorker bool PrefixPath string + BastionHost v3.BastionHost } const ( diff --git a/hosts/tunnel.go b/hosts/tunnel.go index bd699fb3..eddbb24b 100644 --- a/hosts/tunnel.go +++ b/hosts/tunnel.go @@ -6,7 +6,6 @@ import ( "io/ioutil" "os" "path/filepath" - "strings" "net" @@ -83,7 +82,7 @@ func parsePrivateKeyWithPassPhrase(keyBuff string, passphrase []byte) (ssh.Signe return ssh.ParsePrivateKeyWithPassphrase([]byte(keyBuff), passphrase) } -func getSSHConfig(username, sshPrivateKeyString string, passphrase []byte, useAgentAuth bool) (*ssh.ClientConfig, error) { +func getSSHConfig(username, sshPrivateKeyString string, useAgentAuth bool) (*ssh.ClientConfig, error) { config := &ssh.ClientConfig{ User: username, HostKeyCallback: ssh.InsecureIgnoreHostKey(), @@ -104,7 +103,7 @@ func getSSHConfig(username, sshPrivateKeyString string, passphrase []byte, useAg } } - signer, err := getPrivateKeySigner(sshPrivateKeyString, passphrase) + signer, err := parsePrivateKey(sshPrivateKeyString) if err != nil { return config, err } @@ -113,14 +112,6 @@ func getSSHConfig(username, sshPrivateKeyString string, passphrase []byte, useAg return config, nil } -func getPrivateKeySigner(sshPrivateKeyString string, passphrase []byte) (ssh.Signer, error) { - key, err := parsePrivateKey(sshPrivateKeyString) - if err != nil && strings.Contains(err.Error(), "decode encrypted private keys") { - key, err = parsePrivateKeyWithPassPhrase(sshPrivateKeyString, passphrase) - } - return key, err -} - func privateKeyPath(sshKeyPath string) string { if sshKeyPath[:2] == "~/" { sshKeyPath = filepath.Join(os.Getenv("HOME"), sshKeyPath[2:])