mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-11 13:02:14 +00:00
Merge pull request #9905 from justinsb/aws_ssh_user
Don't assume we always SSH as the current user
This commit is contained in:
commit
0f59847cc5
@ -142,21 +142,44 @@ func (s *SSHTunnel) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Interface to allow mocking of ssh.Dial, for testing SSH
|
||||||
|
type sshDialer interface {
|
||||||
|
Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Real implementation of sshDialer
|
||||||
|
type realSSHDialer struct{}
|
||||||
|
|
||||||
|
var _ sshDialer = &realSSHDialer{}
|
||||||
|
|
||||||
|
func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
|
return ssh.Dial(network, addr, config)
|
||||||
|
}
|
||||||
|
|
||||||
// RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
|
// RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
|
||||||
// host along with any SSH-level error.
|
// host as specific user, along with any SSH-level error.
|
||||||
func RunSSHCommand(cmd, host string, signer ssh.Signer) (string, string, int, error) {
|
// If user=="", it will default (like SSH) to os.Getenv("USER")
|
||||||
|
func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
|
||||||
|
return runSSHCommand(&realSSHDialer{}, cmd, user, host, signer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Internal implementation of runSSHCommand, for testing
|
||||||
|
func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
|
||||||
|
if user == "" {
|
||||||
|
user = os.Getenv("USER")
|
||||||
|
}
|
||||||
// Setup the config, dial the server, and open a session.
|
// Setup the config, dial the server, and open a session.
|
||||||
config := &ssh.ClientConfig{
|
config := &ssh.ClientConfig{
|
||||||
User: os.Getenv("USER"),
|
User: user,
|
||||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
|
||||||
}
|
}
|
||||||
client, err := ssh.Dial("tcp", host, config)
|
client, err := dialer.Dial("tcp", host, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", 0, fmt.Errorf("error getting SSH client to host %s: '%v'", host, err)
|
return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", user, host, err)
|
||||||
}
|
}
|
||||||
session, err := client.NewSession()
|
session, err := client.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", 0, fmt.Errorf("error creating session to host %s: '%v'", host, err)
|
return "", "", 0, fmt.Errorf("error creating session to %s@%s: '%v'", user, host, err)
|
||||||
}
|
}
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
|
||||||
@ -176,7 +199,7 @@ func RunSSHCommand(cmd, host string, signer ssh.Signer) (string, string, int, er
|
|||||||
} else {
|
} else {
|
||||||
// Some other kind of error happened (e.g. an IOError); consider the
|
// Some other kind of error happened (e.g. an IOError); consider the
|
||||||
// SSH unsuccessful.
|
// SSH unsuccessful.
|
||||||
err = fmt.Errorf("failed running `%s` on %s: '%v'", cmd, host, err)
|
err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, user, host, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return bout.String(), berr.String(), code, err
|
return bout.String(), berr.String(), code, err
|
||||||
|
@ -24,6 +24,9 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/glog"
|
"github.com/golang/glog"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testSSHServer struct {
|
type testSSHServer struct {
|
||||||
@ -159,3 +162,84 @@ func TestSSHTunnel(t *testing.T) {
|
|||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockSSHDialer struct {
|
||||||
|
network string
|
||||||
|
addr string
|
||||||
|
config *ssh.ClientConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *mockSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
|
d.network = network
|
||||||
|
d.addr = addr
|
||||||
|
d.config = config
|
||||||
|
return nil, fmt.Errorf("mock error from Dial")
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockSigner struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockSigner) PublicKey() ssh.PublicKey {
|
||||||
|
panic("mockSigner.PublicKey not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
|
||||||
|
panic("mockSigner.Sign not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHUser(t *testing.T) {
|
||||||
|
signer := &mockSigner{}
|
||||||
|
|
||||||
|
table := []struct {
|
||||||
|
title string
|
||||||
|
user string
|
||||||
|
host string
|
||||||
|
signer ssh.Signer
|
||||||
|
command string
|
||||||
|
expectUser string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
title: "all values provided",
|
||||||
|
user: "testuser",
|
||||||
|
host: "testhost",
|
||||||
|
signer: signer,
|
||||||
|
command: "uptime",
|
||||||
|
expectUser: "testuser",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "empty user defaults to GetEnv(USER)",
|
||||||
|
user: "",
|
||||||
|
host: "testhost",
|
||||||
|
signer: signer,
|
||||||
|
command: "uptime",
|
||||||
|
expectUser: os.Getenv("USER"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range table {
|
||||||
|
dialer := &mockSSHDialer{}
|
||||||
|
|
||||||
|
_, _, _, err := runSSHCommand(dialer, item.command, item.user, item.host, item.signer)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error (as mock returns error); did not get one")
|
||||||
|
}
|
||||||
|
errString := err.Error()
|
||||||
|
if !strings.HasPrefix(errString, fmt.Sprintf("error getting SSH client to %s@%s:", item.expectUser, item.host)) {
|
||||||
|
t.Errorf("unexpected error: %v", errString)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dialer.network != "tcp" {
|
||||||
|
t.Errorf("unexpected network: %v", dialer.network)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dialer.config.User != item.expectUser {
|
||||||
|
t.Errorf("unexpected user: %v", dialer.config.User)
|
||||||
|
}
|
||||||
|
if len(dialer.config.Auth) != 1 {
|
||||||
|
t.Errorf("unexpected auth: %v", dialer.config.Auth)
|
||||||
|
}
|
||||||
|
// (No way to test Auth - nothing exported?)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -1161,7 +1161,9 @@ func SSH(cmd, host, provider string) (string, string, int, error) {
|
|||||||
return "", "", 0, fmt.Errorf("error getting signer for provider %s: '%v'", provider, err)
|
return "", "", 0, fmt.Errorf("error getting signer for provider %s: '%v'", provider, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return util.RunSSHCommand(cmd, host, signer)
|
user := os.Getenv("KUBE_SSH_USER")
|
||||||
|
// RunSSHCommand will default to Getenv("USER") if user == ""
|
||||||
|
return util.RunSSHCommand(cmd, user, host, signer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getSigner returns an ssh.Signer for the provider ("gce", etc.) that can be
|
// getSigner returns an ssh.Signer for the provider ("gce", etc.) that can be
|
||||||
|
Loading…
Reference in New Issue
Block a user