mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-23 11:50:44 +00:00
Merge pull request #23843 from cjcullen/dialtimeout
Automatic merge from submit-queue. Add a timeout to the sshDialer to prevent indefinite hangs. Prevents the SSH Dialer from hanging forever. Fixes a problem where SSH Tunnels get stuck trying to open. Addresses #23835.
This commit is contained in:
commit
52276bcc6c
@ -111,7 +111,7 @@ func makeSSHTunnel(user string, signer ssh.Signer, host string) (*SSHTunnel, err
|
|||||||
|
|
||||||
func (s *SSHTunnel) Open() error {
|
func (s *SSHTunnel) Open() error {
|
||||||
var err error
|
var err error
|
||||||
s.client, err = ssh.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
|
s.client, err = realTimeoutDialer.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
|
||||||
tunnelOpenCounter.Inc()
|
tunnelOpenCounter.Inc()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tunnelOpenFailCounter.Inc()
|
tunnelOpenFailCounter.Inc()
|
||||||
@ -163,11 +163,43 @@ func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*s
|
|||||||
return ssh.Dial(network, addr, config)
|
return ssh.Dial(network, addr, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang
|
||||||
|
// ssh library can hang indefinitely inside the Dial() call (see issue #23835).
|
||||||
|
// Wrapping all Dial() calls with a conservative timeout provides safety against
|
||||||
|
// getting stuck on that.
|
||||||
|
type timeoutDialer struct {
|
||||||
|
dialer sshDialer
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// 150 seconds is longer than the underlying default TCP backoff delay (127
|
||||||
|
// seconds). This timeout is only intended to catch otherwise uncaught hangs.
|
||||||
|
const sshDialTimeout = 150 * time.Second
|
||||||
|
|
||||||
|
var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout}
|
||||||
|
|
||||||
|
func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
|
var client *ssh.Client
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer runtime.HandleCrash()
|
||||||
|
var err error
|
||||||
|
client, err = d.dialer.Dial(network, addr, config)
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
return client, err
|
||||||
|
case <-time.After(d.timeout):
|
||||||
|
return nil, fmt.Errorf("timed out dialing %s:%s", network, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
|
// RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
|
||||||
// host as specific user, along with any SSH-level error.
|
// host as specific user, along with any SSH-level error.
|
||||||
// If user=="", it will default (like SSH) to os.Getenv("USER")
|
// If user=="", it will default (like SSH) to os.Getenv("USER")
|
||||||
func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
|
func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
|
||||||
return runSSHCommand(&realSSHDialer{}, cmd, user, host, signer, true)
|
return runSSHCommand(realTimeoutDialer, cmd, user, host, signer, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Internal implementation of runSSHCommand, for testing
|
// Internal implementation of runSSHCommand, for testing
|
||||||
|
@ -328,3 +328,39 @@ func TestSSHUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// slowDialer is a test double for sshDialer whose Dial sleeps for a fixed
// delay and then reports a preset outcome.
type slowDialer struct {
	// delay is how long Dial sleeps before returning.
	delay time.Duration
	// err, when non-nil, is returned by Dial in place of a client.
	err error
}
|
||||||
|
|
||||||
|
func (s *slowDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
|
time.Sleep(s.delay)
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return &ssh.Client{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTimeoutDialer(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
delay time.Duration
|
||||||
|
timeout time.Duration
|
||||||
|
err error
|
||||||
|
expectedErrString string
|
||||||
|
}{
|
||||||
|
// delay > timeout should cause ssh.Dial to timeout.
|
||||||
|
{1 * time.Second, 0, nil, "timed out dialing"},
|
||||||
|
// delay < timeout should return the result of the call to the dialer.
|
||||||
|
{0, 1 * time.Second, nil, ""},
|
||||||
|
{0, 1 * time.Second, fmt.Errorf("test dial error"), "test dial error"},
|
||||||
|
}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
dialer := &timeoutDialer{&slowDialer{tc.delay, tc.err}, tc.timeout}
|
||||||
|
_, err := dialer.Dial("tcp", "addr:port", &ssh.ClientConfig{})
|
||||||
|
if len(tc.expectedErrString) == 0 && err != nil ||
|
||||||
|
!strings.Contains(fmt.Sprint(err), tc.expectedErrString) {
|
||||||
|
t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user