diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index 69dbec685bf..ed496d975a0 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -111,7 +111,7 @@ func makeSSHTunnel(user string, signer ssh.Signer, host string) (*SSHTunnel, err func (s *SSHTunnel) Open() error { var err error - s.client, err = realTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config) + s.client, err = defaultTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config) tunnelOpenCounter.Inc() if err != nil { tunnelOpenFailCounter.Inc() @@ -154,21 +154,9 @@ type sshDialer interface { Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) } -// Real implementation of sshDialer -type realSSHDialer struct{} - -var _ sshDialer = &realSSHDialer{} - -func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - return ssh.Dial(network, addr, config) -} - -// timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang +// timeoutDialer implements a Dial() method that will time out. The golang // ssh library can hang indefinitely inside the Dial() call (see issue #23835). -// Wrapping all Dial() calls with a conservative timeout provides safety against -// getting stuck on that. type timeoutDialer struct { - dialer sshDialer timeout time.Duration } @@ -176,30 +164,32 @@ type timeoutDialer struct { // seconds). This timeout is only intended to catch otherwise uncaught hangs. 
const sshDialTimeout = 150 * time.Second -var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout} +var defaultTimeoutDialer sshDialer = &timeoutDialer{sshDialTimeout} func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - var client *ssh.Client - errCh := make(chan error, 1) - go func() { - defer runtime.HandleCrash() - var err error - client, err = d.dialer.Dial(network, addr, config) - errCh <- err - }() - select { - case err := <-errCh: - return client, err - case <-time.After(d.timeout): - return nil, fmt.Errorf("timed out dialing %s:%s", network, addr) + conn, err := net.Dial(network, addr) + if err != nil { + return nil, err } + conn.SetDeadline(time.Now().Add(d.timeout)) + // set to 0 so that conn will not time out after Dial. + defer func() { + conn.SetDeadline(time.Time{}) + }() + // if conn times out, the NewClientConn will close it, so we will not end up + // with hanging goroutines or open file descriptors. + c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return ssh.NewClient(c, chans, reqs), nil } // RunSSHCommand returns the stdout, stderr, and exit code from running cmd on // host as specific user, along with any SSH-level error. 
// If user=="", it will default (like SSH) to os.Getenv("USER") func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) { - return runSSHCommand(realTimeoutDialer, cmd, user, host, signer, true) + return runSSHCommand(defaultTimeoutDialer, cmd, user, host, signer, true) } // Internal implementation of runSSHCommand, for testing diff --git a/pkg/ssh/ssh_test.go b/pkg/ssh/ssh_test.go index 52cdfc504d5..da7570a1ac3 100644 --- a/pkg/ssh/ssh_test.go +++ b/pkg/ssh/ssh_test.go @@ -329,38 +329,49 @@ func TestSSHUser(t *testing.T) { } -type slowDialer struct { - delay time.Duration - err error -} - -func (s *slowDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - time.Sleep(s.delay) - if s.err != nil { - return nil, s.err - } - return &ssh.Client{}, nil -} - func TestTimeoutDialer(t *testing.T) { testCases := []struct { - delay time.Duration timeout time.Duration - err error expectedErrString string }{ - // delay > timeout should cause ssh.Dial to timeout. - {1 * time.Second, 0, nil, "timed out dialing"}, - // delay < timeout should return the result of the call to the dialer. - {0, 1 * time.Second, nil, ""}, - {0, 1 * time.Second, fmt.Errorf("test dial error"), "test dial error"}, + // should cause ssh.Dial to time out. 
+ {0, "i/o timeout"}, + // should succeed + {1 * time.Second, ""}, } for _, tc := range testCases { - dialer := &timeoutDialer{&slowDialer{tc.delay, tc.err}, tc.timeout} - _, err := dialer.Dial("tcp", "addr:port", &ssh.ClientConfig{}) + // setup + private, _, err := GenerateKey(2048) + if err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + server, err := runTestSSHServer("foo", "bar") + if err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + privateData := EncodePrivateKey(private) + tunnel, err := NewSSHTunnelFromBytes("foo", privateData, server.Host) + if err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + tunnel.SSHPort = server.Port + + // test the dialer + dialer := &timeoutDialer{tc.timeout} + client, err := dialer.Dial("tcp", net.JoinHostPort(tunnel.Host, tunnel.SSHPort), tunnel.Config) if len(tc.expectedErrString) == 0 && err != nil || !strings.Contains(fmt.Sprint(err), tc.expectedErrString) { t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err) } + if len(tc.expectedErrString) == 0 { + // verify the connection doesn't time out after the handshake is done. + time.Sleep(tc.timeout + 1*time.Second) + if _, _, err := client.OpenChannel("direct-tcpip", nil); err != nil { + t.Errorf("unexpected error %v", err) + } + } } }