mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-09 03:57:41 +00:00
Merge pull request #9905 from justinsb/aws_ssh_user
Don't assume we always SSH as the current user
This commit is contained in:
commit
0f59847cc5
@ -142,21 +142,44 @@ func (s *SSHTunnel) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Interface to allow mocking of ssh.Dial, for testing SSH
|
||||
type sshDialer interface {
|
||||
Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error)
|
||||
}
|
||||
|
||||
// Real implementation of sshDialer
|
||||
type realSSHDialer struct{}
|
||||
|
||||
var _ sshDialer = &realSSHDialer{}
|
||||
|
||||
func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||
return ssh.Dial(network, addr, config)
|
||||
}
|
||||
|
||||
// RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
|
||||
// host along with any SSH-level error.
|
||||
func RunSSHCommand(cmd, host string, signer ssh.Signer) (string, string, int, error) {
|
||||
// host as specific user, along with any SSH-level error.
|
||||
// If user=="", it will default (like SSH) to os.Getenv("USER")
|
||||
func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
|
||||
return runSSHCommand(&realSSHDialer{}, cmd, user, host, signer)
|
||||
}
|
||||
|
||||
// Internal implementation of runSSHCommand, for testing
|
||||
func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
|
||||
if user == "" {
|
||||
user = os.Getenv("USER")
|
||||
}
|
||||
// Setup the config, dial the server, and open a session.
|
||||
config := &ssh.ClientConfig{
|
||||
User: os.Getenv("USER"),
|
||||
User: user,
|
||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
||||
}
|
||||
client, err := ssh.Dial("tcp", host, config)
|
||||
client, err := dialer.Dial("tcp", host, config)
|
||||
if err != nil {
|
||||
return "", "", 0, fmt.Errorf("error getting SSH client to host %s: '%v'", host, err)
|
||||
return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", user, host, err)
|
||||
}
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return "", "", 0, fmt.Errorf("error creating session to host %s: '%v'", host, err)
|
||||
return "", "", 0, fmt.Errorf("error creating session to %s@%s: '%v'", user, host, err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
@ -176,7 +199,7 @@ func RunSSHCommand(cmd, host string, signer ssh.Signer) (string, string, int, er
|
||||
} else {
|
||||
// Some other kind of error happened (e.g. an IOError); consider the
|
||||
// SSH unsuccessful.
|
||||
err = fmt.Errorf("failed running `%s` on %s: '%v'", cmd, host, err)
|
||||
err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, user, host, err)
|
||||
}
|
||||
}
|
||||
return bout.String(), berr.String(), code, err
|
||||
|
@ -24,6 +24,9 @@ import (
|
||||
|
||||
"github.com/golang/glog"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type testSSHServer struct {
|
||||
@ -159,3 +162,84 @@ func TestSSHTunnel(t *testing.T) {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type mockSSHDialer struct {
|
||||
network string
|
||||
addr string
|
||||
config *ssh.ClientConfig
|
||||
}
|
||||
|
||||
func (d *mockSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||
d.network = network
|
||||
d.addr = addr
|
||||
d.config = config
|
||||
return nil, fmt.Errorf("mock error from Dial")
|
||||
}
|
||||
|
||||
type mockSigner struct {
|
||||
}
|
||||
|
||||
func (s *mockSigner) PublicKey() ssh.PublicKey {
|
||||
panic("mockSigner.PublicKey not implemented")
|
||||
}
|
||||
|
||||
func (s *mockSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
|
||||
panic("mockSigner.Sign not implemented")
|
||||
}
|
||||
|
||||
func TestSSHUser(t *testing.T) {
|
||||
signer := &mockSigner{}
|
||||
|
||||
table := []struct {
|
||||
title string
|
||||
user string
|
||||
host string
|
||||
signer ssh.Signer
|
||||
command string
|
||||
expectUser string
|
||||
}{
|
||||
{
|
||||
title: "all values provided",
|
||||
user: "testuser",
|
||||
host: "testhost",
|
||||
signer: signer,
|
||||
command: "uptime",
|
||||
expectUser: "testuser",
|
||||
},
|
||||
{
|
||||
title: "empty user defaults to GetEnv(USER)",
|
||||
user: "",
|
||||
host: "testhost",
|
||||
signer: signer,
|
||||
command: "uptime",
|
||||
expectUser: os.Getenv("USER"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, item := range table {
|
||||
dialer := &mockSSHDialer{}
|
||||
|
||||
_, _, _, err := runSSHCommand(dialer, item.command, item.user, item.host, item.signer)
|
||||
if err == nil {
|
||||
t.Errorf("expected error (as mock returns error); did not get one")
|
||||
}
|
||||
errString := err.Error()
|
||||
if !strings.HasPrefix(errString, fmt.Sprintf("error getting SSH client to %s@%s:", item.expectUser, item.host)) {
|
||||
t.Errorf("unexpected error: %v", errString)
|
||||
}
|
||||
|
||||
if dialer.network != "tcp" {
|
||||
t.Errorf("unexpected network: %v", dialer.network)
|
||||
}
|
||||
|
||||
if dialer.config.User != item.expectUser {
|
||||
t.Errorf("unexpected user: %v", dialer.config.User)
|
||||
}
|
||||
if len(dialer.config.Auth) != 1 {
|
||||
t.Errorf("unexpected auth: %v", dialer.config.Auth)
|
||||
}
|
||||
// (No way to test Auth - nothing exported?)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1161,7 +1161,9 @@ func SSH(cmd, host, provider string) (string, string, int, error) {
|
||||
return "", "", 0, fmt.Errorf("error getting signer for provider %s: '%v'", provider, err)
|
||||
}
|
||||
|
||||
return util.RunSSHCommand(cmd, host, signer)
|
||||
user := os.Getenv("KUBE_SSH_USER")
|
||||
// RunSSHCommand will default to Getenv("USER") if user == ""
|
||||
return util.RunSSHCommand(cmd, user, host, signer)
|
||||
}
|
||||
|
||||
// getSigner returns an ssh.Signer for the provider ("gce", etc.) that can be
|
||||
|
Loading…
Reference in New Issue
Block a user