diff --git a/pkg/cloudprovider/gce/gce.go b/pkg/cloudprovider/gce/gce.go index 4fdaa889c7a..d7cdb8a8d05 100644 --- a/pkg/cloudprovider/gce/gce.go +++ b/pkg/cloudprovider/gce/gce.go @@ -32,6 +32,7 @@ import ( "github.com/GoogleCloudPlatform/kubernetes/pkg/api/resource" "github.com/GoogleCloudPlatform/kubernetes/pkg/cloudprovider" "github.com/GoogleCloudPlatform/kubernetes/pkg/util" + "github.com/GoogleCloudPlatform/kubernetes/pkg/util/wait" "code.google.com/p/gcfg" compute "code.google.com/p/google-api-go-client/compute/v1" @@ -483,37 +484,46 @@ func (gce *GCECloud) getInstanceByName(name string) (*compute.Instance, error) { } func (gce *GCECloud) AddSSHKeyToAllInstances(user string, keyData []byte) error { - project, err := gce.service.Projects.Get(gce.projectID).Do() - if err != nil { - return err - } - hostname, err := os.Hostname() - if err != nil { - return err - } - keyString := fmt.Sprintf("%s:%s %s@%s", user, strings.TrimSpace(string(keyData)), user, hostname) - found := false - for _, item := range project.CommonInstanceMetadata.Items { - if item.Key == "sshKeys" { - item.Value = addKey(item.Value, keyString) - found = true - break + return wait.Poll(2*time.Second, 30*time.Second, func() (bool, error) { + project, err := gce.service.Projects.Get(gce.projectID).Do() + if err != nil { + glog.Errorf("Could not get project: %v", err) + return false, nil } - } - if !found { - // This is super unlikely, so log. - glog.Infof("Failed to find sshKeys metadata, creating a new item") - project.CommonInstanceMetadata.Items = append(project.CommonInstanceMetadata.Items, - &compute.MetadataItems{ - Key: "sshKeys", - Value: keyString, - }) - } - op, err := gce.service.Projects.SetCommonInstanceMetadata(gce.projectID, project.CommonInstanceMetadata).Do() - if err != nil { - return err - } - return gce.waitForGlobalOp(op) + hostname, err := os.Hostname() + if err != nil { + glog.Errorf("Could not get hostname: %v", err) + return false, nil + } + keyString := fmt.Sprintf("%s:%s %s@%s", user, strings.TrimSpace(string(keyData)), user, hostname) + found := false + for _, item := range project.CommonInstanceMetadata.Items { + if item.Key == "sshKeys" { + item.Value = addKey(item.Value, keyString) + found = true + break + } + } + if !found { + // This is super unlikely, so log. + glog.Infof("Failed to find sshKeys metadata, creating a new item") + project.CommonInstanceMetadata.Items = append(project.CommonInstanceMetadata.Items, + &compute.MetadataItems{ + Key: "sshKeys", + Value: keyString, + }) + } + op, err := gce.service.Projects.SetCommonInstanceMetadata(gce.projectID, project.CommonInstanceMetadata).Do() + if err != nil { + glog.Errorf("Could not Set Metadata: %v", err) + return false, nil + } + if err := gce.waitForGlobalOp(op); err != nil { + glog.Errorf("Could not Set Metadata: %v", err) + return false, nil + } + return true, nil + }) } func addKey(metadataBefore, keyString string) string { diff --git a/pkg/master/master.go b/pkg/master/master.go index 2439411a614..c659d28238b 100644 --- a/pkg/master/master.go +++ b/pkg/master/master.go @@ -210,7 +210,7 @@ type Master struct { InsecureHandler http.Handler // Used for secure proxy - tunnels util.SSHTunnelList + tunnels *util.SSHTunnelList tunnelsLock sync.Mutex installSSHKey InstallSSHKey } @@ -772,7 +772,7 @@ func (m *Master) Dial(net, addr string) (net.Conn, error) { } func (m *Master) needToReplaceTunnels(addrs []string) bool { - if len(m.tunnels) != len(addrs) { + if m.tunnels == nil || m.tunnels.Len() != len(addrs) { return true } // TODO (cjcullen): This doesn't need to be n^2 @@ -850,7 +850,7 @@ func (m *Master) setupSecureProxy(user, keyfile string) { if err := m.loadTunnels(user, keyfile); err != nil { glog.Errorf("Failed to load SSH Tunnels: %v", err) } - if len(m.tunnels) != 0 { + if m.tunnels != nil && m.tunnels.Len() != 0 { // Sleep for 10 seconds if we have some tunnels. // TODO (cjcullen): tunnels can lag behind actually existing nodes. time.Sleep(9 * time.Second) diff --git a/pkg/util/ssh.go b/pkg/util/ssh.go index d4eed8a3795..7d07249aebb 100644 --- a/pkg/util/ssh.go +++ b/pkg/util/ssh.go @@ -207,9 +207,11 @@ type SSHTunnelEntry struct { Tunnel *SSHTunnel } -type SSHTunnelList []SSHTunnelEntry +type SSHTunnelList struct { + entries []SSHTunnelEntry +} -func MakeSSHTunnels(user, keyfile string, addresses []string) (SSHTunnelList, error) { +func MakeSSHTunnels(user, keyfile string, addresses []string) (*SSHTunnelList, error) { tunnels := []SSHTunnelEntry{} for ix := range addresses { addr := addresses[ix] @@ -219,18 +221,22 @@ func MakeSSHTunnels(user, keyfile string, addresses []string) (SSHTunnelList, er } tunnels = append(tunnels, SSHTunnelEntry{addr, tunnel}) } - return tunnels, nil + return &SSHTunnelList{tunnels}, nil } -func (l SSHTunnelList) Open() error { - for ix := range l { - if err := l[ix].Tunnel.Open(); err != nil { - // Remove a failed Open from the list. - glog.Errorf("Failed to open tunnel %v: %v", l[ix], err) - l = append(l[:ix], l[ix+1:]...) +// Open attempts to open all tunnels in the list, and removes any tunnels that +// failed to open. +func (l *SSHTunnelList) Open() error { + var openTunnels []SSHTunnelEntry + for ix := range l.entries { + if err := l.entries[ix].Tunnel.Open(); err != nil { + glog.Errorf("Failed to open tunnel %v: %v", l.entries[ix], err) + } else { + openTunnels = append(openTunnels, l.entries[ix]) } } - if len(l) == 0 { + l.entries = openTunnels + if len(l.entries) == 0 { return errors.New("Failed to open any tunnels.") } return nil @@ -239,9 +245,9 @@ func (l SSHTunnelList) Open() error { // Close asynchronously closes all tunnels in the list after waiting for 1 // minute. Tunnels will still be open upon this function's return, but should // no longer be used. -func (l SSHTunnelList) Close() { - for ix := range l { - entry := l[ix] +func (l *SSHTunnelList) Close() { + for ix := range l.entries { + entry := l.entries[ix] go func() { defer HandleCrash() time.Sleep(1 * time.Minute) @@ -252,22 +258,26 @@ func (l SSHTunnelList) Close() { } } -func (l SSHTunnelList) Dial(network, addr string) (net.Conn, error) { - if len(l) == 0 { +func (l *SSHTunnelList) Dial(network, addr string) (net.Conn, error) { + if len(l.entries) == 0 { return nil, fmt.Errorf("Empty tunnel list.") } - return l[mathrand.Int()%len(l)].Tunnel.Dial(network, addr) + return l.entries[mathrand.Int()%len(l.entries)].Tunnel.Dial(network, addr) } -func (l SSHTunnelList) Has(addr string) bool { - for ix := range l { - if l[ix].Address == addr { +func (l *SSHTunnelList) Has(addr string) bool { + for ix := range l.entries { + if l.entries[ix].Address == addr { return true } } return false } +func (l *SSHTunnelList) Len() int { + return len(l.entries) +} + func EncodePrivateKey(private *rsa.PrivateKey) []byte { return pem.EncodeToMemory(&pem.Block{ Bytes: x509.MarshalPKCS1PrivateKey(private),