mirror of
https://github.com/rancher/rke.git
synced 2025-04-27 19:25:44 +00:00
Use SSH Agent
This defaults to using the SSH Agent or a passwordless key file. It also refactors the Dialer methods a bit to simplify and decouple the host object from the dialer.
This commit is contained in:
parent
b945968af8
commit
f0d1689889
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
.idea
|
||||
/.dapper
|
||||
/bin
|
||||
/dist
|
||||
|
116
hosts/dialer.go
116
hosts/dialer.go
@ -11,82 +11,92 @@ import (
|
||||
type DialerFactory func(h *Host) (func(network, address string) (net.Conn, error), error)
|
||||
|
||||
type dialer struct {
|
||||
host *Host
|
||||
signer ssh.Signer
|
||||
signer ssh.Signer
|
||||
sshKeyString string
|
||||
sshAddress string
|
||||
sshPassphrase []byte
|
||||
username string
|
||||
netConn string
|
||||
dockerSocket string
|
||||
}
|
||||
|
||||
func newDialer(h *Host, kind string) (*dialer, error) {
|
||||
dialer := &dialer{
|
||||
sshAddress: fmt.Sprintf("%s:%s", h.Address, h.Port),
|
||||
username: h.User,
|
||||
dockerSocket: h.DockerSocket,
|
||||
sshKeyString: h.SSHKey,
|
||||
netConn: "unix",
|
||||
sshPassphrase: []byte(h.SavedKeyPhrase),
|
||||
}
|
||||
|
||||
if dialer.sshKeyString == "" {
|
||||
dialer.sshKeyString = privateKeyPath(h.SSHKeyPath)
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case "network", "health":
|
||||
dialer.netConn = "tcp"
|
||||
}
|
||||
|
||||
if len(dialer.dockerSocket) == 0 {
|
||||
dialer.dockerSocket = "/var/run/docker.sock"
|
||||
}
|
||||
|
||||
return dialer, nil
|
||||
}
|
||||
|
||||
func SSHFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
|
||||
key, err := h.checkEncryptedKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to parse the private key: %v", err)
|
||||
}
|
||||
dialer := &dialer{
|
||||
host: h,
|
||||
signer: key,
|
||||
}
|
||||
return dialer.DialDocker, nil
|
||||
dialer, err := newDialer(h, "docker")
|
||||
return dialer.Dial, err
|
||||
}
|
||||
|
||||
func LocalConnFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
|
||||
key, err := h.checkEncryptedKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to parse the private key: %v", err)
|
||||
}
|
||||
dialer := &dialer{
|
||||
host: h,
|
||||
signer: key,
|
||||
}
|
||||
return dialer.DialLocalConn, nil
|
||||
dialer, err := newDialer(h, "network")
|
||||
return dialer.Dial, err
|
||||
}
|
||||
|
||||
func (d *dialer) DialDocker(network, addr string) (net.Conn, error) {
|
||||
sshAddr := fmt.Sprintf("%s:%s", d.host.Address, d.host.Port)
|
||||
// Build SSH client configuration
|
||||
cfg, err := makeSSHConfig(d.host.User, d.signer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error configuring SSH: %v", err)
|
||||
}
|
||||
// Establish connection with SSH server
|
||||
conn, err := ssh.Dial("tcp", sshAddr, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", sshAddr, err)
|
||||
}
|
||||
if len(d.host.DockerSocket) == 0 {
|
||||
d.host.DockerSocket = "/var/run/docker.sock"
|
||||
}
|
||||
remote, err := conn.Dial("unix", d.host.DockerSocket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to dial to Docker socket: %v", err)
|
||||
}
|
||||
return remote, err
|
||||
return d.Dial(network, addr)
|
||||
}
|
||||
|
||||
func (d *dialer) DialLocalConn(network, addr string) (net.Conn, error) {
|
||||
sshAddr := fmt.Sprintf("%s:%s", d.host.Address, d.host.Port)
|
||||
// Build SSH client configuration
|
||||
cfg, err := makeSSHConfig(d.host.User, d.signer)
|
||||
return d.Dial(network, addr)
|
||||
}
|
||||
|
||||
func (d *dialer) Dial(network, addr string) (net.Conn, error) {
|
||||
conn, err := d.getSSHTunnelConnection()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error configuring SSH: %v", err)
|
||||
return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", d.sshAddress, err)
|
||||
}
|
||||
// Establish connection with SSH server
|
||||
conn, err := ssh.Dial("tcp", sshAddr, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", sshAddr, err)
|
||||
|
||||
// Docker Socket....
|
||||
if d.netConn == "unix" {
|
||||
addr = d.dockerSocket
|
||||
network = d.netConn
|
||||
}
|
||||
|
||||
remote, err := conn.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to dial to Local Port [%d] on host [%s]: %v", d.host.LocalConnPort, d.host.Address, err)
|
||||
return nil, fmt.Errorf("Failed to dial to %s: %v", addr, err)
|
||||
}
|
||||
return remote, err
|
||||
}
|
||||
|
||||
func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error) {
|
||||
var factory DialerFactory
|
||||
func (d *dialer) getSSHTunnelConnection() (*ssh.Client, error) {
|
||||
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshPassphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error configuring SSH: %v", err)
|
||||
}
|
||||
|
||||
if dialerFactory == nil {
|
||||
// Establish connection with SSH server
|
||||
return ssh.Dial("tcp", d.sshAddress, cfg)
|
||||
}
|
||||
|
||||
func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error) {
|
||||
factory := dialerFactory
|
||||
if factory == nil {
|
||||
factory = SSHFactory
|
||||
} else {
|
||||
factory = dialerFactory
|
||||
}
|
||||
|
||||
dialer, err := factory(h)
|
||||
|
@ -6,16 +6,14 @@ import (
|
||||
)
|
||||
|
||||
func LocalHealthcheckFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
|
||||
dialer := &dialer{
|
||||
host: h,
|
||||
}
|
||||
return dialer.DialHealthcheckLocally, nil
|
||||
dialer, err := newDialer(h, "health")
|
||||
return dialer.DialHealthcheckLocally, err
|
||||
}
|
||||
|
||||
func (d *dialer) DialHealthcheckLocally(network, addr string) (net.Conn, error) {
|
||||
conn, err := net.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to dial address [%s]: %v", d.host.Address, err)
|
||||
return nil, fmt.Errorf("Failed to dial address [%s]: %v", addr, err)
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
@ -7,14 +7,14 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/rancher/rke/docker"
|
||||
"github.com/rancher/rke/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -82,54 +82,37 @@ func parsePrivateKeyWithPassPhrase(keyBuff string, passphrase []byte) (ssh.Signe
|
||||
return ssh.ParsePrivateKeyWithPassphrase([]byte(keyBuff), passphrase)
|
||||
}
|
||||
|
||||
func makeSSHConfig(user string, signer ssh.Signer) (*ssh.ClientConfig, error) {
|
||||
config := ssh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(signer),
|
||||
},
|
||||
func getSSHConfig(username, sshPrivateKeyString string, passphrase []byte) (*ssh.ClientConfig, error) {
|
||||
config := &ssh.ClientConfig{
|
||||
User: username,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
if sshAgentSock := os.Getenv("SSH_AUTH_SOCK"); sshAgentSock != "" {
|
||||
sshAgent, err := net.Dial("unix", sshAgentSock)
|
||||
if err != nil {
|
||||
return config, fmt.Errorf("Cannot connect to SSH Auth socket %q: %s", sshAgentSock, err)
|
||||
}
|
||||
|
||||
config.Auth = append(config.Auth, ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers))
|
||||
|
||||
logrus.Debugf("using %q SSH_AUTH_SOCK", sshAgentSock)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
signer, err := getPrivateKeySigner(sshPrivateKeyString, passphrase)
|
||||
if err != nil {
|
||||
return config, err
|
||||
}
|
||||
config.Auth = append(config.Auth, ssh.PublicKeys(signer))
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (h *Host) checkEncryptedKey() (ssh.Signer, error) {
|
||||
logrus.Debugf("[ssh] Checking private key")
|
||||
var err error
|
||||
var key ssh.Signer
|
||||
if len(h.SSHKey) > 0 {
|
||||
key, err = parsePrivateKey(h.SSHKey)
|
||||
} else {
|
||||
key, err = parsePrivateKey(privateKeyPath(h.SSHKeyPath))
|
||||
}
|
||||
if err == nil {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// parse encrypted key
|
||||
if strings.Contains(err.Error(), "decode encrypted private keys") {
|
||||
var passphrase []byte
|
||||
if len(h.SavedKeyPhrase) == 0 {
|
||||
fmt.Printf("Passphrase for Private SSH Key: ")
|
||||
passphrase, err = terminal.ReadPassword(int(syscall.Stdin))
|
||||
fmt.Printf("\n")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.SavedKeyPhrase = string(passphrase)
|
||||
} else {
|
||||
passphrase = []byte(h.SavedKeyPhrase)
|
||||
}
|
||||
|
||||
if len(h.SSHKey) > 0 {
|
||||
key, err = parsePrivateKeyWithPassPhrase(h.SSHKey, passphrase)
|
||||
} else {
|
||||
key, err = parsePrivateKeyWithPassPhrase(privateKeyPath(h.SSHKeyPath), passphrase)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func getPrivateKeySigner(sshPrivateKeyString string, passphrase []byte) (ssh.Signer, error) {
|
||||
key, err := parsePrivateKey(sshPrivateKeyString)
|
||||
if err != nil && strings.Contains(err.Error(), "decode encrypted private keys") {
|
||||
key, err = parsePrivateKeyWithPassPhrase(sshPrivateKeyString, passphrase)
|
||||
}
|
||||
return key, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user