diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index 88a17950637..40bab3cea51 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -161,7 +161,17 @@ type realSSHDialer struct{} var _ sshDialer = &realSSHDialer{} func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - return ssh.Dial(network, addr, config) + conn, err := net.DialTimeout(network, addr, config.Timeout) + if err != nil { + return nil, err + } + conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + conn.SetReadDeadline(time.Time{}) + return ssh.NewClient(c, chans, reqs), nil } // timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang @@ -180,20 +190,8 @@ const sshDialTimeout = 150 * time.Second var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout} func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - var client *ssh.Client - errCh := make(chan error, 1) - go func() { - defer runtime.HandleCrash() - var err error - client, err = d.dialer.Dial(network, addr, config) - errCh <- err - }() - select { - case err := <-errCh: - return client, err - case <-time.After(d.timeout): - return nil, fmt.Errorf("timed out dialing %s:%s", network, addr) - } + config.Timeout = d.timeout + return d.dialer.Dial(network, addr, config) } // RunSSHCommand returns the stdout, stderr, and exit code from running cmd on diff --git a/pkg/ssh/ssh_test.go b/pkg/ssh/ssh_test.go index 52cdfc504d5..a9facfe6325 100644 --- a/pkg/ssh/ssh_test.go +++ b/pkg/ssh/ssh_test.go @@ -329,38 +329,28 @@ func TestSSHUser(t *testing.T) { } -type slowDialer struct { - delay time.Duration - err error -} - -func (s *slowDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - time.Sleep(s.delay) - if s.err != nil { - return nil, s.err - } - return &ssh.Client{}, nil -} - func TestTimeoutDialer(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + testCases := []struct { - delay time.Duration timeout time.Duration - err error expectedErrString string }{ // delay > timeout should cause ssh.Dial to timeout. - {1 * time.Second, 0, nil, "timed out dialing"}, - // delay < timeout should return the result of the call to the dialer. - {0, 1 * time.Second, nil, ""}, - {0, 1 * time.Second, fmt.Errorf("test dial error"), "test dial error"}, + {1, "i/o timeout"}, } for _, tc := range testCases { - dialer := &timeoutDialer{&slowDialer{tc.delay, tc.err}, tc.timeout} - _, err := dialer.Dial("tcp", "addr:port", &ssh.ClientConfig{}) + dialer := &timeoutDialer{&realSSHDialer{}, tc.timeout} + _, err := dialer.Dial("tcp", listener.Addr().String(), &ssh.ClientConfig{}) if len(tc.expectedErrString) == 0 && err != nil || !strings.Contains(fmt.Sprint(err), tc.expectedErrString) { t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err) } } + + listener.Close() }