mirror of
https://github.com/rancher/rke.git
synced 2025-05-07 15:57:06 +00:00
Use SSH Agent
This defaults to using the SSH Agent or a passwordless key file. It also refactors the Dialer methods a bit to simplify and decouple the host object from the dialer.
This commit is contained in:
parent
b945968af8
commit
f0d1689889
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
|
.idea
|
||||||
/.dapper
|
/.dapper
|
||||||
/bin
|
/bin
|
||||||
/dist
|
/dist
|
||||||
|
116
hosts/dialer.go
116
hosts/dialer.go
@ -11,82 +11,92 @@ import (
|
|||||||
type DialerFactory func(h *Host) (func(network, address string) (net.Conn, error), error)
|
type DialerFactory func(h *Host) (func(network, address string) (net.Conn, error), error)
|
||||||
|
|
||||||
type dialer struct {
|
type dialer struct {
|
||||||
host *Host
|
signer ssh.Signer
|
||||||
signer ssh.Signer
|
sshKeyString string
|
||||||
|
sshAddress string
|
||||||
|
sshPassphrase []byte
|
||||||
|
username string
|
||||||
|
netConn string
|
||||||
|
dockerSocket string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDialer(h *Host, kind string) (*dialer, error) {
|
||||||
|
dialer := &dialer{
|
||||||
|
sshAddress: fmt.Sprintf("%s:%s", h.Address, h.Port),
|
||||||
|
username: h.User,
|
||||||
|
dockerSocket: h.DockerSocket,
|
||||||
|
sshKeyString: h.SSHKey,
|
||||||
|
netConn: "unix",
|
||||||
|
sshPassphrase: []byte(h.SavedKeyPhrase),
|
||||||
|
}
|
||||||
|
|
||||||
|
if dialer.sshKeyString == "" {
|
||||||
|
dialer.sshKeyString = privateKeyPath(h.SSHKeyPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case "network", "health":
|
||||||
|
dialer.netConn = "tcp"
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(dialer.dockerSocket) == 0 {
|
||||||
|
dialer.dockerSocket = "/var/run/docker.sock"
|
||||||
|
}
|
||||||
|
|
||||||
|
return dialer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SSHFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
|
func SSHFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
|
||||||
key, err := h.checkEncryptedKey()
|
dialer, err := newDialer(h, "docker")
|
||||||
if err != nil {
|
return dialer.Dial, err
|
||||||
return nil, fmt.Errorf("Failed to parse the private key: %v", err)
|
|
||||||
}
|
|
||||||
dialer := &dialer{
|
|
||||||
host: h,
|
|
||||||
signer: key,
|
|
||||||
}
|
|
||||||
return dialer.DialDocker, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func LocalConnFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
|
func LocalConnFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
|
||||||
key, err := h.checkEncryptedKey()
|
dialer, err := newDialer(h, "network")
|
||||||
if err != nil {
|
return dialer.Dial, err
|
||||||
return nil, fmt.Errorf("Failed to parse the private key: %v", err)
|
|
||||||
}
|
|
||||||
dialer := &dialer{
|
|
||||||
host: h,
|
|
||||||
signer: key,
|
|
||||||
}
|
|
||||||
return dialer.DialLocalConn, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dialer) DialDocker(network, addr string) (net.Conn, error) {
|
func (d *dialer) DialDocker(network, addr string) (net.Conn, error) {
|
||||||
sshAddr := fmt.Sprintf("%s:%s", d.host.Address, d.host.Port)
|
return d.Dial(network, addr)
|
||||||
// Build SSH client configuration
|
|
||||||
cfg, err := makeSSHConfig(d.host.User, d.signer)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Error configuring SSH: %v", err)
|
|
||||||
}
|
|
||||||
// Establish connection with SSH server
|
|
||||||
conn, err := ssh.Dial("tcp", sshAddr, cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", sshAddr, err)
|
|
||||||
}
|
|
||||||
if len(d.host.DockerSocket) == 0 {
|
|
||||||
d.host.DockerSocket = "/var/run/docker.sock"
|
|
||||||
}
|
|
||||||
remote, err := conn.Dial("unix", d.host.DockerSocket)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Failed to dial to Docker socket: %v", err)
|
|
||||||
}
|
|
||||||
return remote, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dialer) DialLocalConn(network, addr string) (net.Conn, error) {
|
func (d *dialer) DialLocalConn(network, addr string) (net.Conn, error) {
|
||||||
sshAddr := fmt.Sprintf("%s:%s", d.host.Address, d.host.Port)
|
return d.Dial(network, addr)
|
||||||
// Build SSH client configuration
|
}
|
||||||
cfg, err := makeSSHConfig(d.host.User, d.signer)
|
|
||||||
|
func (d *dialer) Dial(network, addr string) (net.Conn, error) {
|
||||||
|
conn, err := d.getSSHTunnelConnection()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error configuring SSH: %v", err)
|
return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", d.sshAddress, err)
|
||||||
}
|
}
|
||||||
// Establish connection with SSH server
|
|
||||||
conn, err := ssh.Dial("tcp", sshAddr, cfg)
|
// Docker Socket....
|
||||||
if err != nil {
|
if d.netConn == "unix" {
|
||||||
return nil, fmt.Errorf("Failed to dial ssh using address [%s]: %v", sshAddr, err)
|
addr = d.dockerSocket
|
||||||
|
network = d.netConn
|
||||||
}
|
}
|
||||||
|
|
||||||
remote, err := conn.Dial(network, addr)
|
remote, err := conn.Dial(network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Failed to dial to Local Port [%d] on host [%s]: %v", d.host.LocalConnPort, d.host.Address, err)
|
return nil, fmt.Errorf("Failed to dial to %s: %v", addr, err)
|
||||||
}
|
}
|
||||||
return remote, err
|
return remote, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error) {
|
func (d *dialer) getSSHTunnelConnection() (*ssh.Client, error) {
|
||||||
var factory DialerFactory
|
cfg, err := getSSHConfig(d.username, d.sshKeyString, d.sshPassphrase)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Error configuring SSH: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if dialerFactory == nil {
|
// Establish connection with SSH server
|
||||||
|
return ssh.Dial("tcp", d.sshAddress, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Host) newHTTPClient(dialerFactory DialerFactory) (*http.Client, error) {
|
||||||
|
factory := dialerFactory
|
||||||
|
if factory == nil {
|
||||||
factory = SSHFactory
|
factory = SSHFactory
|
||||||
} else {
|
|
||||||
factory = dialerFactory
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dialer, err := factory(h)
|
dialer, err := factory(h)
|
||||||
|
@ -6,16 +6,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func LocalHealthcheckFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
|
func LocalHealthcheckFactory(h *Host) (func(network, address string) (net.Conn, error), error) {
|
||||||
dialer := &dialer{
|
dialer, err := newDialer(h, "health")
|
||||||
host: h,
|
return dialer.DialHealthcheckLocally, err
|
||||||
}
|
|
||||||
return dialer.DialHealthcheckLocally, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dialer) DialHealthcheckLocally(network, addr string) (net.Conn, error) {
|
func (d *dialer) DialHealthcheckLocally(network, addr string) (net.Conn, error) {
|
||||||
conn, err := net.Dial(network, addr)
|
conn, err := net.Dial(network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Failed to dial address [%s]: %v", d.host.Address, err)
|
return nil, fmt.Errorf("Failed to dial address [%s]: %v", addr, err)
|
||||||
}
|
}
|
||||||
return conn, err
|
return conn, err
|
||||||
}
|
}
|
||||||
|
@ -7,14 +7,14 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/docker/docker/client"
|
"github.com/docker/docker/client"
|
||||||
"github.com/rancher/rke/docker"
|
"github.com/rancher/rke/docker"
|
||||||
"github.com/rancher/rke/log"
|
"github.com/rancher/rke/log"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
"golang.org/x/crypto/ssh/agent"
|
||||||
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -82,54 +82,37 @@ func parsePrivateKeyWithPassPhrase(keyBuff string, passphrase []byte) (ssh.Signe
|
|||||||
return ssh.ParsePrivateKeyWithPassphrase([]byte(keyBuff), passphrase)
|
return ssh.ParsePrivateKeyWithPassphrase([]byte(keyBuff), passphrase)
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeSSHConfig(user string, signer ssh.Signer) (*ssh.ClientConfig, error) {
|
func getSSHConfig(username, sshPrivateKeyString string, passphrase []byte) (*ssh.ClientConfig, error) {
|
||||||
config := ssh.ClientConfig{
|
config := &ssh.ClientConfig{
|
||||||
User: user,
|
User: username,
|
||||||
Auth: []ssh.AuthMethod{
|
|
||||||
ssh.PublicKeys(signer),
|
|
||||||
},
|
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return &config, nil
|
if sshAgentSock := os.Getenv("SSH_AUTH_SOCK"); sshAgentSock != "" {
|
||||||
|
sshAgent, err := net.Dial("unix", sshAgentSock)
|
||||||
|
if err != nil {
|
||||||
|
return config, fmt.Errorf("Cannot connect to SSH Auth socket %q: %s", sshAgentSock, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Auth = append(config.Auth, ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers))
|
||||||
|
|
||||||
|
logrus.Debugf("using %q SSH_AUTH_SOCK", sshAgentSock)
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := getPrivateKeySigner(sshPrivateKeyString, passphrase)
|
||||||
|
if err != nil {
|
||||||
|
return config, err
|
||||||
|
}
|
||||||
|
config.Auth = append(config.Auth, ssh.PublicKeys(signer))
|
||||||
|
|
||||||
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Host) checkEncryptedKey() (ssh.Signer, error) {
|
func getPrivateKeySigner(sshPrivateKeyString string, passphrase []byte) (ssh.Signer, error) {
|
||||||
logrus.Debugf("[ssh] Checking private key")
|
key, err := parsePrivateKey(sshPrivateKeyString)
|
||||||
var err error
|
if err != nil && strings.Contains(err.Error(), "decode encrypted private keys") {
|
||||||
var key ssh.Signer
|
key, err = parsePrivateKeyWithPassPhrase(sshPrivateKeyString, passphrase)
|
||||||
if len(h.SSHKey) > 0 {
|
|
||||||
key, err = parsePrivateKey(h.SSHKey)
|
|
||||||
} else {
|
|
||||||
key, err = parsePrivateKey(privateKeyPath(h.SSHKeyPath))
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse encrypted key
|
|
||||||
if strings.Contains(err.Error(), "decode encrypted private keys") {
|
|
||||||
var passphrase []byte
|
|
||||||
if len(h.SavedKeyPhrase) == 0 {
|
|
||||||
fmt.Printf("Passphrase for Private SSH Key: ")
|
|
||||||
passphrase, err = terminal.ReadPassword(int(syscall.Stdin))
|
|
||||||
fmt.Printf("\n")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.SavedKeyPhrase = string(passphrase)
|
|
||||||
} else {
|
|
||||||
passphrase = []byte(h.SavedKeyPhrase)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(h.SSHKey) > 0 {
|
|
||||||
key, err = parsePrivateKeyWithPassPhrase(h.SSHKey, passphrase)
|
|
||||||
} else {
|
|
||||||
key, err = parsePrivateKeyWithPassPhrase(privateKeyPath(h.SSHKeyPath), passphrase)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return key, err
|
return key, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user