diff --git a/test/e2e/framework/BUILD b/test/e2e/framework/BUILD index ab4868a01d5..fae77f670a5 100644 --- a/test/e2e/framework/BUILD +++ b/test/e2e/framework/BUILD @@ -33,6 +33,7 @@ go_library( "rs_util.go", "service_util.go", "size.go", + "ssh.go", "statefulset_utils.go", "test_context.go", "upgrade_util.go", diff --git a/test/e2e/framework/ssh.go b/test/e2e/framework/ssh.go index fc1e1ca8149..2949d16d38b 100644 --- a/test/e2e/framework/ssh.go +++ b/test/e2e/framework/ssh.go @@ -17,13 +17,16 @@ limitations under the License. package framework import ( + "bytes" "fmt" "net" "os" "path/filepath" + "time" "golang.org/x/crypto/ssh" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/wait" clientset "k8s.io/client-go/kubernetes" sshutil "k8s.io/kubernetes/pkg/ssh" ) @@ -143,6 +146,14 @@ func SSH(cmd, host, provider string) (SSHResult, error) { result.User = os.Getenv("USER") } + if bastion := os.Getenv("KUBE_SSH_BASTION"); len(bastion) > 0 { + stdout, stderr, code, err := RunSSHCommandViaBastion(cmd, result.User, bastion, host, signer) + result.Stdout = stdout + result.Stderr = stderr + result.Code = code + return result, err + } + stdout, stderr, code, err := sshutil.RunSSHCommand(cmd, result.User, host, signer) result.Stdout = stdout result.Stderr = stderr @@ -151,6 +162,74 @@ func SSH(cmd, host, provider string) (SSHResult, error) { return result, err } +// RunSSHCommandViaBastion returns the stdout, stderr, and exit code from running cmd on +// host as specific user, along with any SSH-level error. It uses an SSH proxy to connect +// to bastion, then via that tunnel connects to the remote host. Similar to +// sshutil.RunSSHCommand but scoped to the needs of the test infrastructure. +func RunSSHCommandViaBastion(cmd, user, bastion, host string, signer ssh.Signer) (string, string, int, error) { + // Setup the config, dial the server, and open a session. + config := &ssh.ClientConfig{ + User: user, + Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 150 * time.Second, + } + bastionClient, err := ssh.Dial("tcp", bastion, config) + if err != nil { + err = wait.Poll(5*time.Second, 20*time.Second, func() (bool, error) { + fmt.Printf("error dialing %s@%s: '%v', retrying\n", user, bastion, err) + if bastionClient, err = ssh.Dial("tcp", bastion, config); err != nil { + return false, err + } + return true, nil + }) + } + if err != nil { + return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: %v", user, bastion, err) + } + defer bastionClient.Close() + + conn, err := bastionClient.Dial("tcp", host) + if err != nil { + return "", "", 0, fmt.Errorf("error dialing %s from bastion: %v", host, err) + } + defer conn.Close() + + ncc, chans, reqs, err := ssh.NewClientConn(conn, host, config) + if err != nil { + return "", "", 0, fmt.Errorf("error creating forwarding connection %s from bastion: %v", host, err) + } + client := ssh.NewClient(ncc, chans, reqs) + defer client.Close() + + session, err := client.NewSession() + if err != nil { + return "", "", 0, fmt.Errorf("error creating session to %s@%s from bastion: '%v'", user, host, err) + } + defer session.Close() + + // Run the command. + code := 0 + var bout, berr bytes.Buffer + session.Stdout, session.Stderr = &bout, &berr + if err = session.Run(cmd); err != nil { + // Check whether the command failed to run or didn't complete. + if exiterr, ok := err.(*ssh.ExitError); ok { + // If we got an ExitError and the exit code is nonzero, we'll + // consider the SSH itself successful (just that the command run + // errored on the host). + if code = exiterr.ExitStatus(); code != 0 { + err = nil + } + } else { + // Some other kind of error happened (e.g. an IOError); consider the + // SSH unsuccessful. + err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, user, host, err) + } + } + return bout.String(), berr.String(), code, err +} + func LogSSHResult(result SSHResult) { remote := fmt.Sprintf("%s@%s", result.User, result.Host) Logf("ssh %s: command: %s", remote, result.Cmd)