From cde4f6d613b42b06946a6b68b0d55f11e8aedadb Mon Sep 17 00:00:00 2001
From: CJ Cullen
Date: Mon, 4 Apr 2016 15:51:49 -0700
Subject: [PATCH] Add a timeout to the sshDialer to prevent indefinite hangs.

If the wrapped Dial() call only succeeds after the timeout has already been
reported to the caller, the late client is now closed instead of being leaked.

---
Review notes (ignored by git am):
 * timeoutDialer.Dial closes a client that arrives after the timeout fired.
 * time.After replaced by a timer that is stopped on return, so a fast dial
   does not leave a 150 second timer pending.
 * Added TestTimeoutDialerClosesLateClient as a regression test.
 * The "index" hashes below are kept from the original patch and do not
   describe the amended post-image.

 pkg/ssh/ssh.go      | 56 ++++++++++++++++++++++++++++++++++++++--
 pkg/ssh/ssh_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 128 insertions(+), 2 deletions(-)

diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go
index a85bb760537..44541574737 100644
--- a/pkg/ssh/ssh.go
+++ b/pkg/ssh/ssh.go
@@ -111,7 +111,7 @@ func makeSSHTunnel(user string, signer ssh.Signer, host string) (*SSHTunnel, err
 
 func (s *SSHTunnel) Open() error {
 	var err error
-	s.client, err = ssh.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
+	s.client, err = realTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
 	tunnelOpenCounter.Inc()
 	if err != nil {
 		tunnelOpenFailCounter.Inc()
@@ -163,11 +163,63 @@ func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*s
 	return ssh.Dial(network, addr, config)
 }
 
+// timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang
+// ssh library can hang indefinitely inside the Dial() call (see issue #23835).
+// Wrapping all Dial() calls with a conservative timeout provides safety against
+// getting stuck on that.
+type timeoutDialer struct {
+	dialer  sshDialer
+	timeout time.Duration
+}
+
+// 150 seconds is longer than the underlying default TCP backoff delay (127
+// seconds). This timeout is only intended to catch otherwise uncaught hangs.
+const sshDialTimeout = 150 * time.Second
+
+var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout}
+
+// Dial runs the wrapped dialer in its own goroutine and gives up after
+// d.timeout. If the wrapped Dial() only succeeds after the caller has already
+// been handed the timeout error, the late client is closed rather than leaked.
+func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
+	type dialResult struct {
+		client *ssh.Client
+		err    error
+	}
+	// resultCh is unbuffered so that a send only succeeds while the caller is
+	// still waiting; a late result can therefore never be silently dropped.
+	resultCh := make(chan dialResult)
+	// abandoned is closed once the caller has given up on this dial.
+	abandoned := make(chan struct{})
+	go func() {
+		defer runtime.HandleCrash()
+		client, err := d.dialer.Dial(network, addr, config)
+		select {
+		case resultCh <- dialResult{client: client, err: err}:
+		case <-abandoned:
+			// Nobody will ever use this connection; close it (best effort).
+			// The Conn check tolerates dialers that return a zero-value client.
+			if err == nil && client != nil && client.Conn != nil {
+				_ = client.Close()
+			}
+		}
+	}()
+	timer := time.NewTimer(d.timeout)
+	defer timer.Stop()
+	select {
+	case res := <-resultCh:
+		return res.client, res.err
+	case <-timer.C:
+		close(abandoned)
+		return nil, fmt.Errorf("timed out dialing %s:%s", network, addr)
+	}
+}
+
 // RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
 // host as specific user, along with any SSH-level error.
 // If user=="", it will default (like SSH) to os.Getenv("USER")
 func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
-	return runSSHCommand(&realSSHDialer{}, cmd, user, host, signer, true)
+	return runSSHCommand(realTimeoutDialer, cmd, user, host, signer, true)
 }
 
 // Internal implementation of runSSHCommand, for testing
diff --git a/pkg/ssh/ssh_test.go b/pkg/ssh/ssh_test.go
index bbc1641e9aa..d8ce2d9a613 100644
--- a/pkg/ssh/ssh_test.go
+++ b/pkg/ssh/ssh_test.go
@@ -328,3 +328,77 @@ func TestSSHUser(t *testing.T) {
 	}
 
 }
+
+type slowDialer struct {
+	delay time.Duration
+	err   error
+}
+
+func (s *slowDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
+	time.Sleep(s.delay)
+	if s.err != nil {
+		return nil, s.err
+	}
+	return &ssh.Client{}, nil
+}
+
+func TestTimeoutDialer(t *testing.T) {
+	testCases := []struct {
+		delay             time.Duration
+		timeout           time.Duration
+		err               error
+		expectedErrString string
+	}{
+		// delay > timeout should cause ssh.Dial to time out.
+		{1 * time.Second, 0, nil, "timed out dialing"},
+		// delay < timeout should return the result of the call to the dialer.
+		{0, 1 * time.Second, nil, ""},
+		{0, 1 * time.Second, fmt.Errorf("test dial error"), "test dial error"},
+	}
+	for _, tc := range testCases {
+		dialer := &timeoutDialer{&slowDialer{tc.delay, tc.err}, tc.timeout}
+		_, err := dialer.Dial("tcp", "addr:port", &ssh.ClientConfig{})
+		if len(tc.expectedErrString) == 0 && err != nil ||
+			!strings.Contains(fmt.Sprint(err), tc.expectedErrString) {
+			t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err)
+		}
+	}
+}
+
+// closeRecorderConn is an ssh.Conn stub that only records Close() calls.
+type closeRecorderConn struct {
+	ssh.Conn
+	closed chan struct{}
+}
+
+func (c *closeRecorderConn) Close() error {
+	close(c.closed)
+	return nil
+}
+
+// lateDialer hands out a fixed client after sleeping for delay.
+type lateDialer struct {
+	delay  time.Duration
+	client *ssh.Client
+}
+
+func (l *lateDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
+	time.Sleep(l.delay)
+	return l.client, nil
+}
+
+// A dial that succeeds only after the timeout has been reported must not leak
+// its connection.
+func TestTimeoutDialerClosesLateClient(t *testing.T) {
+	closed := make(chan struct{})
+	late := &lateDialer{50 * time.Millisecond, &ssh.Client{Conn: &closeRecorderConn{closed: closed}}}
+	dialer := &timeoutDialer{late, 0}
+	if _, err := dialer.Dial("tcp", "addr:port", &ssh.ClientConfig{}); err == nil {
+		t.Fatalf("Expected a timeout error; got none")
+	}
+	select {
+	case <-closed:
+	case <-time.After(5 * time.Second):
+		t.Errorf("Expected the late client to be closed after the timeout")
+	}
+}