1
0
mirror of https://github.com/rancher/rke.git synced 2025-04-27 11:21:08 +00:00

Use SSH Agent

This defaults to using the SSH Agent or a passwordless key file.
It also refactors the Dialer methods a bit to simplify and decouple
the host object from the dialer.
This commit is contained in:
Bill Maxwell 2018-02-26 14:27:51 -07:00 committed by Darren Shepherd
parent b945968af8
commit f0d1689889
4 changed files with 95 additions and 103 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea
/.dapper
/bin
/dist

View File

@ -11,82 +11,92 @@ import (
type DialerFactory func(h *Host) (func(network, address string) (net.Conn, error), error)
type dialer struct {
host *Host
signer ssh.Signer
signer ssh.Signer
sshKeyString string
sshAddress string
sshPassphrase []byte
username string
netConn string
dockerSocket string
}
func newDialer(h *Host, kind string) (*dialer, error) {
dialer := &dialer{
sshAddress: fmt.Sprintf("%s:%s", h.Address, h.Port),
username: h.User,
dockerSocket: h.DockerSocket,
sshKeyString: h.SSHKey,
netConn: "unix",
sshPassphrase: []byte(h.SavedKeyPhrase),
}
if dialer.sshKeyString == "" {
dialer.sshKeyString = privateKeyPath(h.SSHKeyPath)
}
switch kind {
case "network", "health":
dialer.netConn = "tcp"
}
if len(dialer.dockerSocket) == 0 {
dialer.dockerSocket = "/var/run/docker.sock"
}
return dialer, nil
}
func SSHFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
key, err := h.checkEncryptedKey()
if err != nil {
return nil, fmt.Errorf("Failed to parse the private key: %v", err)
}
dialer := &dialer{
host: h,
signer: key,
}
return dialer.DialDocker, nil
dialer, err := newDialer(h, "docker")
return dialer.Dial, err
}
func LocalConnFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
key, err := h.checkEncryptedKey()
if err != nil {
return nil, fmt.Errorf("Failed to parse the private key: %v", err)
}
dialer := &dialer{
host: h,
signer: key,
}
return dialer.DialLocalConn, nil
dialer, err := newDialer(h, "network")
return dialer.Dial, err
}
func (d *dialer) DialDocker(network, addr string) (net.Conn, error) {
sshAddr := fmt.Sprintf("%s:%s", d.host.Address, d.host.Port)
// Build SSH client configuration
cfg, err := makeSSHConfig(d.host.User, d.signer)
if err != nil {
return nil, fmt.Errorf("Error configuring SSH: %v", err)
}
// Establish connection with SSH server
conn, err := ssh.Dial("tcp", sshAddr, cfg)
if err != nil {
return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", sshAddr, err)
}
if len(d.host.DockerSocket) == 0 {
d.host.DockerSocket = "/var/run/docker.sock"
}
remote, err := conn.Dial("unix", d.host.DockerSocket)
if err != nil {
return nil, fmt.Errorf("Failed to dial to Docker socket: %v", err)
}
return remote, err
return d.Dial(network, addr)
}
func (d *dialer) DialLocalConn(network, addr string) (net.Conn, error) {
sshAddr := fmt.Sprintf("%s:%s", d.host.Address, d.host.Port)
// Build SSH client configuration
cfg, err := makeSSHConfig(d.host.User, d.signer)
return d.Dial(network, addr)
}
func (d *dialer) Dial(network, addr string) (net.Conn, error) {
conn, err := d.getSSHTunnelConnection()
if err != nil {
return nil, fmt.Errorf("Error configuring SSH: %v", err)
return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", d.sshAddress, err)
}
// Establish connection with SSH server
conn, err := ssh.Dial("tcp", sshAddr, cfg)
if err != nil {
return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", sshAddr, err)
// Docker Socket....
if d.netConn == "unix" {
addr = d.dockerSocket
network = d.netConn
}
remote, err := conn.Dial(network, addr)
if err != nil {
return nil, fmt.Errorf("Failed to dial to Local Port [%d] on host [%s]: %v", d.host.LocalConnPort, d.host.Address, err)
return nil, fmt.Errorf("Failed to dial to %s: %v", addr, err)
}
return remote, err
}
func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error) {
var factory DialerFactory
func (d *dialer) getSSHTunnelConnection() (*ssh.Client, error) {
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshPassphrase)
if err != nil {
return nil, fmt.Errorf("Error configuring SSH: %v", err)
}
if dialerFactory == nil {
// Establish connection with SSH server
return ssh.Dial("tcp", d.sshAddress, cfg)
}
func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error) {
factory := dialerFactory
if factory == nil {
factory = SSHFactory
} else {
factory = dialerFactory
}
dialer, err := factory(h)

View File

@ -6,16 +6,14 @@ import (
)
func LocalHealthcheckFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
dialer := &dialer{
host: h,
}
return dialer.DialHealthcheckLocally, nil
dialer, err := newDialer(h, "health")
return dialer.DialHealthcheckLocally, err
}
func (d *dialer) DialHealthcheckLocally(network, addr string) (net.Conn, error) {
conn, err := net.Dial(network, addr)
if err != nil {
return nil, fmt.Errorf("Failed to dial address [%s]: %v", d.host.Address, err)
return nil, fmt.Errorf("Failed to dial address [%s]: %v", addr, err)
}
return conn, err
}

View File

@ -7,14 +7,14 @@ import (
"os"
"path/filepath"
"strings"
"syscall"
"github.com/docker/docker/client"
"github.com/rancher/rke/docker"
"github.com/rancher/rke/log"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
"golang.org/x/crypto/ssh/agent"
"net"
)
const (
@ -82,54 +82,37 @@ func parsePrivateKeyWithPassPhrase(keyBuff string, passphrase []byte) (ssh.Signe
return ssh.ParsePrivateKeyWithPassphrase([]byte(keyBuff), passphrase)
}
func makeSSHConfig(user string, signer ssh.Signer) (*ssh.ClientConfig, error) {
config := ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
func getSSHConfig(username, sshPrivateKeyString string, passphrase []byte) (*ssh.ClientConfig, error) {
config := &ssh.ClientConfig{
User: username,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
return &config, nil
if sshAgentSock := os.Getenv("SSH_AUTH_SOCK"); sshAgentSock != "" {
sshAgent, err := net.Dial("unix", sshAgentSock)
if err != nil {
return config, fmt.Errorf("Cannot connect to SSH Auth socket %q: %s", sshAgentSock, err)
}
config.Auth = append(config.Auth, ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers))
logrus.Debugf("using %q SSH_AUTH_SOCK", sshAgentSock)
return config, nil
}
signer, err := getPrivateKeySigner(sshPrivateKeyString, passphrase)
if err != nil {
return config, err
}
config.Auth = append(config.Auth, ssh.PublicKeys(signer))
return config, nil
}
func (h *Host) checkEncryptedKey() (ssh.Signer, error) {
logrus.Debugf("[ssh] Checking private key")
var err error
var key ssh.Signer
if len(h.SSHKey) > 0 {
key, err = parsePrivateKey(h.SSHKey)
} else {
key, err = parsePrivateKey(privateKeyPath(h.SSHKeyPath))
}
if err == nil {
return key, nil
}
// parse encrypted key
if strings.Contains(err.Error(), "decode encrypted private keys") {
var passphrase []byte
if len(h.SavedKeyPhrase) == 0 {
fmt.Printf("Passphrase for Private SSH Key: ")
passphrase, err = terminal.ReadPassword(int(syscall.Stdin))
fmt.Printf("\n")
if err != nil {
return nil, err
}
h.SavedKeyPhrase = string(passphrase)
} else {
passphrase = []byte(h.SavedKeyPhrase)
}
if len(h.SSHKey) > 0 {
key, err = parsePrivateKeyWithPassPhrase(h.SSHKey, passphrase)
} else {
key, err = parsePrivateKeyWithPassPhrase(privateKeyPath(h.SSHKeyPath), passphrase)
}
if err != nil {
return nil, err
}
func getPrivateKeySigner(sshPrivateKeyString string, passphrase []byte) (ssh.Signer, error) {
key, err := parsePrivateKey(sshPrivateKeyString)
if err != nil && strings.Contains(err.Error(), "decode encrypted private keys") {
key, err = parsePrivateKeyWithPassPhrase(sshPrivateKeyString, passphrase)
}
return key, err
}