mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-25 20:53:33 +00:00
Use the builtin timeout provided by SSH
This commit is contained in:
parent
a01a493d5d
commit
e31dda98c1
@ -161,7 +161,17 @@ type realSSHDialer struct{}
|
|||||||
var _ sshDialer = &realSSHDialer{}
|
var _ sshDialer = &realSSHDialer{}
|
||||||
|
|
||||||
func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
return ssh.Dial(network, addr, config)
|
conn, err := net.DialTimeout(network, addr, config.Timeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||||
|
c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn.SetReadDeadline(time.Time{})
|
||||||
|
return ssh.NewClient(c, chans, reqs), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang
|
// timeoutDialer wraps an sshDialer with a timeout around Dial(). The golang
|
||||||
@ -180,20 +190,8 @@ const sshDialTimeout = 150 * time.Second
|
|||||||
var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout}
|
var realTimeoutDialer sshDialer = &timeoutDialer{&realSSHDialer{}, sshDialTimeout}
|
||||||
|
|
||||||
func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
var client *ssh.Client
|
config.Timeout = d.timeout
|
||||||
errCh := make(chan error, 1)
|
return d.dialer.Dial(network, addr, config)
|
||||||
go func() {
|
|
||||||
defer runtime.HandleCrash()
|
|
||||||
var err error
|
|
||||||
client, err = d.dialer.Dial(network, addr, config)
|
|
||||||
errCh <- err
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case err := <-errCh:
|
|
||||||
return client, err
|
|
||||||
case <-time.After(d.timeout):
|
|
||||||
return nil, fmt.Errorf("timed out dialing %s:%s", network, addr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
|
// RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
|
||||||
|
@ -329,38 +329,28 @@ func TestSSHUser(t *testing.T) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type slowDialer struct {
|
|
||||||
delay time.Duration
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *slowDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
|
|
||||||
time.Sleep(s.delay)
|
|
||||||
if s.err != nil {
|
|
||||||
return nil, s.err
|
|
||||||
}
|
|
||||||
return &ssh.Client{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTimeoutDialer(t *testing.T) {
|
func TestTimeoutDialer(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
delay time.Duration
|
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
err error
|
|
||||||
expectedErrString string
|
expectedErrString string
|
||||||
}{
|
}{
|
||||||
// delay > timeout should cause ssh.Dial to timeout.
|
// delay > timeout should cause ssh.Dial to timeout.
|
||||||
{1 * time.Second, 0, nil, "timed out dialing"},
|
{1, "i/o timeout"},
|
||||||
// delay < timeout should return the result of the call to the dialer.
|
|
||||||
{0, 1 * time.Second, nil, ""},
|
|
||||||
{0, 1 * time.Second, fmt.Errorf("test dial error"), "test dial error"},
|
|
||||||
}
|
}
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
dialer := &timeoutDialer{&slowDialer{tc.delay, tc.err}, tc.timeout}
|
dialer := &timeoutDialer{&realSSHDialer{}, tc.timeout}
|
||||||
_, err := dialer.Dial("tcp", "addr:port", &ssh.ClientConfig{})
|
_, err := dialer.Dial("tcp", listener.Addr().String(), &ssh.ClientConfig{})
|
||||||
if len(tc.expectedErrString) == 0 && err != nil ||
|
if len(tc.expectedErrString) == 0 && err != nil ||
|
||||||
!strings.Contains(fmt.Sprint(err), tc.expectedErrString) {
|
!strings.Contains(fmt.Sprint(err), tc.expectedErrString) {
|
||||||
t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err)
|
t.Errorf("Expected error to contain %q; got %v", tc.expectedErrString, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
listener.Close()
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user