diff --git a/pkg/util/ssh.go b/pkg/util/ssh.go index 8891f53d057..0af722b835f 100644 --- a/pkg/util/ssh.go +++ b/pkg/util/ssh.go @@ -142,21 +142,42 @@ func (s *SSHTunnel) Close() error { return nil } +// Interface to allow mocking of ssh.Dial, for testing SSH +type sshDialer interface { + Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) +} + +// Real implementation of sshDialer +type realSSHDialer struct{} + +func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { + return ssh.Dial(network, addr, config) +} + // RunSSHCommand returns the stdout, stderr, and exit code from running cmd on -// host along with any SSH-level error. -func RunSSHCommand(cmd, host string, signer ssh.Signer) (string, string, int, error) { +// host as specific user, along with any SSH-level error. +// If user=="", it will default (like SSH) to os.Getenv("USER") +func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) { + return runSSHCommand(&realSSHDialer{}, cmd, user, host, signer) +} + +// Internal implementation of runSSHCommand, for testing +func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer) (string, string, int, error) { + if user == "" { + user = os.Getenv("USER") + } // Setup the config, dial the server, and open a session. config := &ssh.ClientConfig{ - User: os.Getenv("USER"), + User: user, Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, } - client, err := ssh.Dial("tcp", host, config) + client, err := dialer.Dial("tcp", host, config) if err != nil { - return "", "", 0, fmt.Errorf("error getting SSH client to host %s: '%v'", host, err) + return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", user, host, err) } session, err := client.NewSession() if err != nil { - return "", "", 0, fmt.Errorf("error creating session to host %s: '%v'", host, err) + return "", "", 0, fmt.Errorf("error creating session to %s@%s: '%v'", user, host, err) } defer session.Close() @@ -176,7 +197,7 @@ func RunSSHCommand(cmd, host string, signer ssh.Signer) (string, string, int, er } else { // Some other kind of error happened (e.g. an IOError); consider the // SSH unsuccessful. - err = fmt.Errorf("failed running `%s` on %s: '%v'", cmd, host, err) + err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, user, host, err) } } return bout.String(), berr.String(), code, err diff --git a/pkg/util/ssh_test.go b/pkg/util/ssh_test.go index bef10285523..df53b648bb0 100644 --- a/pkg/util/ssh_test.go +++ b/pkg/util/ssh_test.go @@ -24,6 +24,9 @@ import ( "github.com/golang/glog" "golang.org/x/crypto/ssh" + "io" + "os" + "strings" ) type testSSHServer struct { @@ -159,3 +162,84 @@ func TestSSHTunnel(t *testing.T) { t.Errorf("unexpected error: %v", err) } } + +type mockSSHDialer struct { + network string + addr string + config *ssh.ClientConfig +} + +func (d *mockSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { + d.network = network + d.addr = addr + d.config = config + return nil, fmt.Errorf("mock error from Dial") +} + +type mockSigner struct { +} + +func (s *mockSigner) PublicKey() ssh.PublicKey { + panic("mockSigner.PublicKey not implemented") +} + +func (s *mockSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { + panic("mockSigner.Sign not implemented") +} + +func TestSSHUser(t *testing.T) { + signer := &mockSigner{} + + table := []struct { + title string + user string + host string + signer ssh.Signer + command string + expectUser string + }{ + { + title: "all values provided", + user: "testuser", + host: "testhost", + signer: signer, + command: "uptime", + expectUser: "testuser", + }, + { + title: "empty user defaults to GetEnv(USER)", + user: "", + host: "testhost", + signer: signer, + command: "uptime", + expectUser: os.Getenv("USER"), + }, + } + + for _, item := range table { + dialer := &mockSSHDialer{} + + _, _, _, err := runSSHCommand(dialer, item.command, item.user, item.host, item.signer) + if err == nil { + t.Errorf("expected error (as mock returns error); did not get one") + } + errString := err.Error() + if !strings.HasPrefix(errString, fmt.Sprintf("error getting SSH client to %s@%s:", item.expectUser, item.host)) { + t.Errorf("unexpected error: %v", errString) + } + + if dialer.network != "tcp" { + t.Errorf("unexpected network: %v", dialer.network) + } + + if dialer.config.User != item.expectUser { + t.Errorf("unexpected user: %v", dialer.config.User) + } + if len(dialer.config.Auth) != 1 { + t.Errorf("unexpected auth: %v", dialer.config.Auth) + } + // (No way to test Auth - nothing exported?) + + } + +} diff --git a/test/e2e/util.go b/test/e2e/util.go index b87823472c1..0b3f133d459 100644 --- a/test/e2e/util.go +++ b/test/e2e/util.go @@ -1161,7 +1161,9 @@ func SSH(cmd, host, provider string) (string, string, int, error) { return "", "", 0, fmt.Errorf("error getting signer for provider %s: '%v'", provider, err) } - return util.RunSSHCommand(cmd, host, signer) + user := os.Getenv("KUBE_SSH_USER") + // RunSSHCommand will default to Getenv("USER") if user == "" + return util.RunSSHCommand(cmd, user, host, signer) } // getSigner returns an ssh.Signer for the provider ("gce", etc.) that can be