diff --git a/pkg/master/tunneler.go b/pkg/master/tunneler.go index 9cb04b34c20..e79b9e5cf0f 100644 --- a/pkg/master/tunneler.go +++ b/pkg/master/tunneler.go @@ -24,6 +24,7 @@ import ( "sync/atomic" "time" + "k8s.io/kubernetes/pkg/ssh" "k8s.io/kubernetes/pkg/util" "github.com/golang/glog" @@ -47,7 +48,7 @@ type SSHTunneler struct { InstallSSHKey InstallSSHKey HealthCheckURL *url.URL - tunnels *util.SSHTunnelList + tunnels *ssh.SSHTunnelList lastSync int64 // Seconds since Epoch lastSyncMetric prometheus.GaugeFunc clock util.Clock @@ -97,7 +98,7 @@ func (c *SSHTunneler) Run(getAddresses AddressFunc) { } } - c.tunnels = util.NewSSHTunnelList(c.SSHUser, c.SSHKeyfile, c.HealthCheckURL, c.stopChan) + c.tunnels = ssh.NewSSHTunnelList(c.SSHUser, c.SSHKeyfile, c.HealthCheckURL, c.stopChan) // Sync loop to ensure that the SSH key has been installed. c.installSSHKeySyncLoop(c.SSHUser, publicKeyFile) // Sync tunnelList w/ nodes. @@ -129,12 +130,12 @@ func (c *SSHTunneler) installSSHKeySyncLoop(user, publicKeyfile string) { glog.Error("Won't attempt to install ssh key: InstallSSHKey function is nil") return } - key, err := util.ParsePublicKeyFromFile(publicKeyfile) + key, err := ssh.ParsePublicKeyFromFile(publicKeyfile) if err != nil { glog.Errorf("Failed to load public key: %v", err) return } - keyData, err := util.EncodeSSHKey(key) + keyData, err := ssh.EncodeSSHKey(key) if err != nil { glog.Errorf("Failed to encode public key: %v", err) return @@ -161,7 +162,7 @@ func (c *SSHTunneler) nodesSyncLoop() { } func generateSSHKey(privateKeyfile, publicKeyfile string) error { - private, public, err := util.GenerateKey(2048) + private, public, err := ssh.GenerateKey(2048) if err != nil { return err } @@ -176,10 +177,10 @@ func generateSSHKey(privateKeyfile, publicKeyfile string) error { glog.Errorf("Failed to remove stale private key: %v", err) } } - if err := ioutil.WriteFile(privateKeyfile, util.EncodePrivateKey(private), 0600); err != nil { + if err := ioutil.WriteFile(privateKeyfile, ssh.EncodePrivateKey(private), 0600); err != nil { return err } - publicKeyBytes, err := util.EncodePublicKey(public) + publicKeyBytes, err := ssh.EncodePublicKey(public) if err != nil { return err } diff --git a/pkg/util/ssh.go b/pkg/ssh/ssh.go similarity index 96% rename from pkg/util/ssh.go rename to pkg/ssh/ssh.go index 04a06afb3e7..a015bee916f 100644 --- a/pkg/util/ssh.go +++ b/pkg/ssh/ssh.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package util +package ssh import ( "bytes" @@ -39,8 +39,10 @@ import ( "github.com/prometheus/client_golang/prometheus" "golang.org/x/crypto/ssh" + "k8s.io/kubernetes/pkg/util" utilnet "k8s.io/kubernetes/pkg/util/net" "k8s.io/kubernetes/pkg/util/runtime" + "k8s.io/kubernetes/pkg/util/wait" ) var ( @@ -166,11 +168,11 @@ func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*s // host as specific user, along with any SSH-level error. // If user=="", it will default (like SSH) to os.Getenv("USER") func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) { - return runSSHCommand(&realSSHDialer{}, cmd, user, host, signer) + return runSSHCommand(&realSSHDialer{}, cmd, user, host, signer, true) } // Internal implementation of runSSHCommand, for testing -func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer) (string, string, int, error) { +func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer, retry bool) (string, string, int, error) { if user == "" { user = os.Getenv("USER") } @@ -180,6 +182,15 @@ func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer) Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, } client, err := dialer.Dial("tcp", host, config) + if err != nil && retry { + err = wait.Poll(5*time.Second, 20*time.Second, func() (bool, error) { + fmt.Printf("error dialing %s@%s: '%v', retrying\n", user, host, err) + if client, err = dialer.Dial("tcp", host, config); err != nil { + return false, nil + } + return true, nil + }) + } if err != nil { return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", user, host, err) } @@ -286,7 +297,7 @@ func NewSSHTunnelList(user, keyfile string, healthCheckURL *url.URL, stopChan ch healthCheckURL: healthCheckURL, } healthCheckPoll := 1 * time.Minute - go Until(func() { + go util.Until(func() { l.tunnelsLock.Lock() defer l.tunnelsLock.Unlock() // Healthcheck each tunnel every minute diff --git a/pkg/util/ssh_test.go b/pkg/ssh/ssh_test.go similarity index 99% rename from pkg/util/ssh_test.go rename to pkg/ssh/ssh_test.go index 072cfff6f1a..bbc1641e9aa 100644 --- a/pkg/util/ssh_test.go +++ b/pkg/ssh/ssh_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package util +package ssh import ( "fmt" @@ -304,7 +304,7 @@ func TestSSHUser(t *testing.T) { for _, item := range table { dialer := &mockSSHDialer{} - _, _, _, err := runSSHCommand(dialer, item.command, item.user, item.host, item.signer) + _, _, _, err := runSSHCommand(dialer, item.command, item.user, item.host, item.signer, false) if err == nil { t.Errorf("expected error (as mock returns error); did not get one") } diff --git a/test/e2e/util.go b/test/e2e/util.go index f705e722003..cc4533f434f 100644 --- a/test/e2e/util.go +++ b/test/e2e/util.go @@ -51,6 +51,7 @@ import ( "k8s.io/kubernetes/pkg/kubelet/util/format" "k8s.io/kubernetes/pkg/labels" "k8s.io/kubernetes/pkg/runtime" + sshutil "k8s.io/kubernetes/pkg/ssh" "k8s.io/kubernetes/pkg/util" deploymentutil "k8s.io/kubernetes/pkg/util/deployment" "k8s.io/kubernetes/pkg/util/sets" @@ -2227,7 +2228,7 @@ func SSH(cmd, host, provider string) (SSHResult, error) { result.User = os.Getenv("USER") } - stdout, stderr, code, err := util.RunSSHCommand(cmd, result.User, host, signer) + stdout, stderr, code, err := sshutil.RunSSHCommand(cmd, result.User, host, signer) result.Stdout = stdout result.Stderr = stderr result.Code = code @@ -2332,7 +2333,7 @@ func getSigner(provider string) (ssh.Signer, error) { // If there is an env. variable override, use that. aws_keyfile := os.Getenv("AWS_SSH_KEY") if len(aws_keyfile) != 0 { - return util.MakePrivateKeySignerFromFile(aws_keyfile) + return sshutil.MakePrivateKeySignerFromFile(aws_keyfile) } // Otherwise revert to home dir keyfile = "kube_aws_rsa" @@ -2341,7 +2342,7 @@ func getSigner(provider string) (ssh.Signer, error) { } key := filepath.Join(keydir, keyfile) - return util.MakePrivateKeySignerFromFile(key) + return sshutil.MakePrivateKeySignerFromFile(key) } // checkPodsRunning returns whether all pods whose names are listed in podNames