diff --git a/.gitignore b/.gitignore index 78dca3ea..9fd6ea5f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea /.dapper /bin /dist diff --git a/hosts/dialer.go b/hosts/dialer.go index 3bded003..735c150d 100644 --- a/hosts/dialer.go +++ b/hosts/dialer.go @@ -11,82 +11,92 @@ import ( type DialerFactory func(h *Host) (func(network, address string) (net.Conn, error), error) type dialer struct { - host *Host - signer ssh.Signer + signer ssh.Signer + sshKeyString string + sshAddress string + sshPassphrase []byte + username string + netConn string + dockerSocket string +} + +func newDialer(h *Host, kind string) (*dialer, error) { + dialer := &dialer{ + sshAddress: fmt.Sprintf("%s:%s", h.Address, h.Port), + username: h.User, + dockerSocket: h.DockerSocket, + sshKeyString: h.SSHKey, + netConn: "unix", + sshPassphrase: []byte(h.SavedKeyPhrase), + } + + if dialer.sshKeyString == "" { + dialer.sshKeyString = privateKeyPath(h.SSHKeyPath) + } + + switch kind { + case "network", "health": + dialer.netConn = "tcp" + } + + if len(dialer.dockerSocket) == 0 { + dialer.dockerSocket = "/var/run/docker.sock" + } + + return dialer, nil } func SSHFactory(h *Host) (func(network, address string) (net.Conn, error), error) { - key, err := h.checkEncryptedKey() - if err != nil { - return nil, fmt.Errorf("Failed to parse the private key: %v", err) - } - dialer := &dialer{ - host: h, - signer: key, - } - return dialer.DialDocker, nil + dialer, err := newDialer(h, "docker") + return dialer.Dial, err } func LocalConnFactory(h *Host) (func(network, address string) (net.Conn, error), error) { - key, err := h.checkEncryptedKey() - if err != nil { - return nil, fmt.Errorf("Failed to parse the private key: %v", err) - } - dialer := &dialer{ - host: h, - signer: key, - } - return dialer.DialLocalConn, nil + dialer, err := newDialer(h, "network") + return dialer.Dial, err } func (d *dialer) DialDocker(network, addr string) (net.Conn, error) { - sshAddr := fmt.Sprintf("%s:%s", d.host.Address, d.host.Port) - // Build SSH client configuration - cfg, err := makeSSHConfig(d.host.User, d.signer) - if err != nil { - return nil, fmt.Errorf("Error configuring SSH: %v", err) - } - // Establish connection with SSH server - conn, err := ssh.Dial("tcp", sshAddr, cfg) - if err != nil { - return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", sshAddr, err) - } - if len(d.host.DockerSocket) == 0 { - d.host.DockerSocket = "/var/run/docker.sock" - } - remote, err := conn.Dial("unix", d.host.DockerSocket) - if err != nil { - return nil, fmt.Errorf("Failed to dial to Docker socket: %v", err) - } - return remote, err + return d.Dial(network, addr) } func (d *dialer) DialLocalConn(network, addr string) (net.Conn, error) { - sshAddr := fmt.Sprintf("%s:%s", d.host.Address, d.host.Port) - // Build SSH client configuration - cfg, err := makeSSHConfig(d.host.User, d.signer) + return d.Dial(network, addr) +} + +func (d *dialer) Dial(network, addr string) (net.Conn, error) { + conn, err := d.getSSHTunnelConnection() if err != nil { - return nil, fmt.Errorf("Error configuring SSH: %v", err) + return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", d.sshAddress, err) } - // Establish connection with SSH server - conn, err := ssh.Dial("tcp", sshAddr, cfg) - if err != nil { - return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", sshAddr, err) + + // Docker Socket.... + if d.netConn == "unix" { + addr = d.dockerSocket + network = d.netConn } + remote, err := conn.Dial(network, addr) if err != nil { - return nil, fmt.Errorf("Failed to dial to Local Port [%d] on host [%s]: %v", d.host.LocalConnPort, d.host.Address, err) + return nil, fmt.Errorf("Failed to dial to %s: %v", addr, err) } return remote, err } -func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error) { - var factory DialerFactory +func (d *dialer) getSSHTunnelConnection() (*ssh.Client, error) { + cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshPassphrase) + if err != nil { + return nil, fmt.Errorf("Error configuring SSH: %v", err) + } - if dialerFactory == nil { + // Establish connection with SSH server + return ssh.Dial("tcp", d.sshAddress, cfg) +} + +func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error) { + factory := dialerFactory + if factory == nil { factory = SSHFactory - } else { - factory = dialerFactory } dialer, err := factory(h) diff --git a/hosts/local.go b/hosts/local.go index 88f3e7c8..b34c5864 100644 --- a/hosts/local.go +++ b/hosts/local.go @@ -6,16 +6,14 @@ import ( ) func LocalHealthcheckFactory(h *Host) (func(network, address string) (net.Conn, error), error) { - dialer := &dialer{ - host: h, - } - return dialer.DialHealthcheckLocally, nil + dialer, err := newDialer(h, "health") + return dialer.DialHealthcheckLocally, err } func (d *dialer) DialHealthcheckLocally(network, addr string) (net.Conn, error) { conn, err := net.Dial(network, addr) if err != nil { - return nil, fmt.Errorf("Failed to dial address [%s]: %v", d.host.Address, err) + return nil, fmt.Errorf("Failed to dial address [%s]: %v", addr, err) } return conn, err } diff --git a/hosts/tunnel.go b/hosts/tunnel.go index 859ab829..5eeff356 100644 --- a/hosts/tunnel.go +++ b/hosts/tunnel.go @@ -7,14 +7,14 @@ import ( "os" "path/filepath" "strings" - "syscall" "github.com/docker/docker/client" "github.com/rancher/rke/docker" "github.com/rancher/rke/log" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/terminal" + "golang.org/x/crypto/ssh/agent" + "net" ) const ( @@ -82,54 +82,37 @@ func parsePrivateKeyWithPassPhrase(keyBuff string, passphrase []byte) (ssh.Signe return ssh.ParsePrivateKeyWithPassphrase([]byte(keyBuff), passphrase) } -func makeSSHConfig(user string, signer ssh.Signer) (*ssh.ClientConfig, error) { - config := ssh.ClientConfig{ - User: user, - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(signer), - }, +func getSSHConfig(username, sshPrivateKeyString string, passphrase []byte) (*ssh.ClientConfig, error) { + config := &ssh.ClientConfig{ + User: username, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - return &config, nil + if sshAgentSock := os.Getenv("SSH_AUTH_SOCK"); sshAgentSock != "" { + sshAgent, err := net.Dial("unix", sshAgentSock) + if err != nil { + return config, fmt.Errorf("Cannot connect to SSH Auth socket %q: %s", sshAgentSock, err) + } + + config.Auth = append(config.Auth, ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)) + + logrus.Debugf("using %q SSH_AUTH_SOCK", sshAgentSock) + return config, nil + } + + signer, err := getPrivateKeySigner(sshPrivateKeyString, passphrase) + if err != nil { + return config, err + } + config.Auth = append(config.Auth, ssh.PublicKeys(signer)) + + return config, nil } -func (h *Host) checkEncryptedKey() (ssh.Signer, error) { - logrus.Debugf("[ssh] Checking private key") - var err error - var key ssh.Signer - if len(h.SSHKey) > 0 { - key, err = parsePrivateKey(h.SSHKey) - } else { - key, err = parsePrivateKey(privateKeyPath(h.SSHKeyPath)) - } - if err == nil { - return key, nil - } - - // parse encrypted key - if strings.Contains(err.Error(), "decode encrypted private keys") { - var passphrase []byte - if len(h.SavedKeyPhrase) == 0 { - fmt.Printf("Passphrase for Private SSH Key: ") - passphrase, err = terminal.ReadPassword(int(syscall.Stdin)) - fmt.Printf("\n") - if err != nil { - return nil, err - } - h.SavedKeyPhrase = string(passphrase) - } else { - passphrase = []byte(h.SavedKeyPhrase) - } - - if len(h.SSHKey) > 0 { - key, err = parsePrivateKeyWithPassPhrase(h.SSHKey, passphrase) - } else { - key, err = parsePrivateKeyWithPassPhrase(privateKeyPath(h.SSHKeyPath), passphrase) - } - if err != nil { - return nil, err - } +func getPrivateKeySigner(sshPrivateKeyString string, passphrase []byte) (ssh.Signer, error) { + key, err := parsePrivateKey(sshPrivateKeyString) + if err != nil && strings.Contains(err.Error(), "decode encrypted private keys") { + key, err = parsePrivateKeyWithPassPhrase(sshPrivateKeyString, passphrase) } return key, err }