Merge pull request #9905 from justinsb/aws_ssh_user

Don't assume we always SSH as the current user
This commit is contained in:
Jeff Lowdermilk 2015-06-22 10:37:07 -07:00
commit 0f59847cc5
3 changed files with 117 additions and 8 deletions

View File

@ -142,21 +142,44 @@ func (s *SSHTunnel) Close() error {
return nil
}
// Interface to allow mocking of ssh.Dial, for testing SSH
type sshDialer interface {
Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error)
}
// Real implementation of sshDialer
type realSSHDialer struct{}
var _ sshDialer = &realSSHDialer{}
func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
return ssh.Dial(network, addr, config)
}
// RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
// host along with any SSH-level error.
func RunSSHCommand(cmd, host string, signer ssh.Signer) (string, string, int, error) {
// host as specific user, along with any SSH-level error.
// If user=="", it will default (like SSH) to os.Getenv("USER")
func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
return runSSHCommand(&realSSHDialer{}, cmd, user, host, signer)
}
// Internal implementation of runSSHCommand, for testing
func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
if user == "" {
user = os.Getenv("USER")
}
// Setup the config, dial the server, and open a session.
config := &ssh.ClientConfig{
User: os.Getenv("USER"),
User: user,
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
}
client, err := ssh.Dial("tcp", host, config)
client, err := dialer.Dial("tcp", host, config)
if err != nil {
return "", "", 0, fmt.Errorf("error getting SSH client to host %s: '%v'", host, err)
return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", user, host, err)
}
session, err := client.NewSession()
if err != nil {
return "", "", 0, fmt.Errorf("error creating session to host %s: '%v'", host, err)
return "", "", 0, fmt.Errorf("error creating session to %s@%s: '%v'", user, host, err)
}
defer session.Close()
@ -176,7 +199,7 @@ func RunSSHCommand(cmd, host string, signer ssh.Signer) (string, string, int, er
} else {
// Some other kind of error happened (e.g. an IOError); consider the
// SSH unsuccessful.
err = fmt.Errorf("failed running `%s` on %s: '%v'", cmd, host, err)
err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, user, host, err)
}
}
return bout.String(), berr.String(), code, err

View File

@ -24,6 +24,9 @@ import (
"github.com/golang/glog"
"golang.org/x/crypto/ssh"
"io"
"os"
"strings"
)
type testSSHServer struct {
@ -159,3 +162,84 @@ func TestSSHTunnel(t *testing.T) {
t.Errorf("unexpected error: %v", err)
}
}
type mockSSHDialer struct {
network string
addr string
config *ssh.ClientConfig
}
func (d *mockSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
d.network = network
d.addr = addr
d.config = config
return nil, fmt.Errorf("mock error from Dial")
}
type mockSigner struct {
}
func (s *mockSigner) PublicKey() ssh.PublicKey {
panic("mockSigner.PublicKey not implemented")
}
func (s *mockSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
panic("mockSigner.Sign not implemented")
}
func TestSSHUser(t *testing.T) {
signer := &mockSigner{}
table := []struct {
title string
user string
host string
signer ssh.Signer
command string
expectUser string
}{
{
title: "all values provided",
user: "testuser",
host: "testhost",
signer: signer,
command: "uptime",
expectUser: "testuser",
},
{
title: "empty user defaults to GetEnv(USER)",
user: "",
host: "testhost",
signer: signer,
command: "uptime",
expectUser: os.Getenv("USER"),
},
}
for _, item := range table {
dialer := &mockSSHDialer{}
_, _, _, err := runSSHCommand(dialer, item.command, item.user, item.host, item.signer)
if err == nil {
t.Errorf("expected error (as mock returns error); did not get one")
}
errString := err.Error()
if !strings.HasPrefix(errString, fmt.Sprintf("error getting SSH client to %s@%s:", item.expectUser, item.host)) {
t.Errorf("unexpected error: %v", errString)
}
if dialer.network != "tcp" {
t.Errorf("unexpected network: %v", dialer.network)
}
if dialer.config.User != item.expectUser {
t.Errorf("unexpected user: %v", dialer.config.User)
}
if len(dialer.config.Auth) != 1 {
t.Errorf("unexpected auth: %v", dialer.config.Auth)
}
// (No way to test Auth - nothing exported?)
}
}

View File

@ -1161,7 +1161,9 @@ func SSH(cmd, host, provider string) (string, string, int, error) {
return "", "", 0, fmt.Errorf("error getting signer for provider %s: '%v'", provider, err)
}
return util.RunSSHCommand(cmd, host, signer)
user := os.Getenv("KUBE_SSH_USER")
// RunSSHCommand will default to Getenv("USER") if user == ""
return util.RunSSHCommand(cmd, user, host, signer)
}
// getSigner returns an ssh.Signer for the provider ("gce", etc.) that can be