diff --git a/cluster/hosts.go b/cluster/hosts.go index ded06248..ee0f30d3 100644 --- a/cluster/hosts.go +++ b/cluster/hosts.go @@ -2,28 +2,37 @@ package cluster import ( "fmt" + "os" + "strings" + "syscall" "github.com/rancher/rke/hosts" "github.com/rancher/rke/pki" "github.com/rancher/rke/services" "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" ) func (c *Cluster) TunnelHosts() error { + key, err := checkEncryptedKey() + if err != nil { + return fmt.Errorf("Failed to parse the private key: %v", err) + } for i := range c.EtcdHosts { - err := c.EtcdHosts[i].TunnelUp() + err := c.EtcdHosts[i].TunnelUp(key) if err != nil { return fmt.Errorf("Failed to set up SSH tunneling for Etcd hosts: %v", err) } } for i := range c.ControlPlaneHosts { - err := c.ControlPlaneHosts[i].TunnelUp() + err := c.ControlPlaneHosts[i].TunnelUp(key) if err != nil { return fmt.Errorf("Failed to set up SSH tunneling for Control hosts: %v", err) } } for i := range c.WorkerHosts { - err := c.WorkerHosts[i].TunnelUp() + err := c.WorkerHosts[i].TunnelUp(key) if err != nil { return fmt.Errorf("Failed to set up SSH tunneling for Worker hosts: %v", err) } @@ -75,3 +84,29 @@ func (c *Cluster) SetUpHosts() error { } return nil } + +func checkEncryptedKey() (ssh.Signer, error) { + logrus.Infof("[ssh] Checking private key") + key, err := hosts.ParsePrivateKey(privateKeyPath()) + if err != nil { + if strings.Contains(err.Error(), "decode encrypted private keys") { + fmt.Printf("Passphrase for Private SSH Key: ") + passphrase, err := terminal.ReadPassword(int(syscall.Stdin)) + fmt.Printf("\n") + if err != nil { + return nil, err + } + key, err = hosts.ParsePrivateKeyWithPassPhrase(privateKeyPath(), passphrase) + if err != nil { + return nil, err + } + } else { + return nil, err + } + } + return key, nil +} + +func privateKeyPath() string { + return os.Getenv("HOME") + "/.ssh/id_rsa" +} diff --git a/hosts/dialer.go b/hosts/dialer.go index dcbbbfd7..4bad4df0 100644 --- a/hosts/dialer.go +++ b/hosts/dialer.go @@ -5,7 +5,6 @@ import ( "io/ioutil" "net" "net/http" - "os" "github.com/docker/docker/client" "github.com/sirupsen/logrus" @@ -13,7 +12,8 @@ import ( ) type dialer struct { - host *Host + host *Host + signer ssh.Signer } const ( @@ -23,7 +23,7 @@ const ( func (d *dialer) Dial(network, addr string) (net.Conn, error) { sshAddr := d.host.IP + ":22" // Build SSH client configuration - cfg, err := makeSSHConfig(d.host.User) + cfg, err := makeSSHConfig(d.host.User, d.signer) if err != nil { logrus.Fatalf("Error configuring SSH: %v", err) } @@ -42,11 +42,12 @@ func (d *dialer) Dial(network, addr string) (net.Conn, error) { return remote, err } -func (h *Host) TunnelUp() error { +func (h *Host) TunnelUp(signer ssh.Signer) error { logrus.Infof("[ssh] Start tunnel for host [%s]", h.AdvertisedHostname) dialer := &dialer{ - host: h, + host: h, + signer: signer, } httpClient := &http.Client{ Transport: &http.Transport{ @@ -64,26 +65,21 @@ func (h *Host) TunnelUp() error { return nil } -func privateKeyPath() string { - return os.Getenv("HOME") + "/.ssh/id_rsa" -} - -// Get private key for ssh authentication -func parsePrivateKey(keyPath string) (ssh.Signer, error) { +func ParsePrivateKey(keyPath string) (ssh.Signer, error) { buff, _ := ioutil.ReadFile(keyPath) return ssh.ParsePrivateKey(buff) } -func makeSSHConfig(user string) (*ssh.ClientConfig, error) { - key, err := parsePrivateKey(privateKeyPath()) - if err != nil { - return nil, err - } +func ParsePrivateKeyWithPassPhrase(keyPath string, passphrase []byte) (ssh.Signer, error) { + buff, _ := ioutil.ReadFile(keyPath) + return ssh.ParsePrivateKeyWithPassphrase(buff, passphrase) +} +func makeSSHConfig(user string, signer ssh.Signer) (*ssh.ClientConfig, error) { config := ssh.ClientConfig{ User: user, Auth: []ssh.AuthMethod{ - ssh.PublicKeys(key), + ssh.PublicKeys(signer), }, HostKeyCallback: ssh.InsecureIgnoreHostKey(), }