Merge pull request #9848 from cjcullen/fwfix

Fix mislooping in ssh.go. Add retries to AddSSHKeys.
This commit is contained in:
Brendan Burns 2015-06-16 11:10:30 -07:00
commit 96c244eacf
3 changed files with 72 additions and 52 deletions

View File

@ -32,6 +32,7 @@ import (
"github.com/GoogleCloudPlatform/kubernetes/pkg/api/resource"
"github.com/GoogleCloudPlatform/kubernetes/pkg/cloudprovider"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/wait"
"code.google.com/p/gcfg"
compute "code.google.com/p/google-api-go-client/compute/v1"
@ -483,37 +484,46 @@ func (gce *GCECloud) getInstanceByName(name string) (*compute.Instance, error) {
}
func (gce *GCECloud) AddSSHKeyToAllInstances(user string, keyData []byte) error {
project, err := gce.service.Projects.Get(gce.projectID).Do()
if err != nil {
return err
}
hostname, err := os.Hostname()
if err != nil {
return err
}
keyString := fmt.Sprintf("%s:%s %s@%s", user, strings.TrimSpace(string(keyData)), user, hostname)
found := false
for _, item := range project.CommonInstanceMetadata.Items {
if item.Key == "sshKeys" {
item.Value = addKey(item.Value, keyString)
found = true
break
return wait.Poll(2*time.Second, 30*time.Second, func() (bool, error) {
project, err := gce.service.Projects.Get(gce.projectID).Do()
if err != nil {
glog.Errorf("Could not get project: %v", err)
return false, nil
}
}
if !found {
// This is super unlikely, so log.
glog.Infof("Failed to find sshKeys metadata, creating a new item")
project.CommonInstanceMetadata.Items = append(project.CommonInstanceMetadata.Items,
&compute.MetadataItems{
Key: "sshKeys",
Value: keyString,
})
}
op, err := gce.service.Projects.SetCommonInstanceMetadata(gce.projectID, project.CommonInstanceMetadata).Do()
if err != nil {
return err
}
return gce.waitForGlobalOp(op)
hostname, err := os.Hostname()
if err != nil {
glog.Errorf("Could not get hostname: %v", err)
return false, nil
}
keyString := fmt.Sprintf("%s:%s %s@%s", user, strings.TrimSpace(string(keyData)), user, hostname)
found := false
for _, item := range project.CommonInstanceMetadata.Items {
if item.Key == "sshKeys" {
item.Value = addKey(item.Value, keyString)
found = true
break
}
}
if !found {
// This is super unlikely, so log.
glog.Infof("Failed to find sshKeys metadata, creating a new item")
project.CommonInstanceMetadata.Items = append(project.CommonInstanceMetadata.Items,
&compute.MetadataItems{
Key: "sshKeys",
Value: keyString,
})
}
op, err := gce.service.Projects.SetCommonInstanceMetadata(gce.projectID, project.CommonInstanceMetadata).Do()
if err != nil {
glog.Errorf("Could not Set Metadata: %v", err)
return false, nil
}
if err := gce.waitForGlobalOp(op); err != nil {
glog.Errorf("Could not Set Metadata: %v", err)
return false, nil
}
return true, nil
})
}
func addKey(metadataBefore, keyString string) string {

View File

@ -210,7 +210,7 @@ type Master struct {
InsecureHandler http.Handler
// Used for secure proxy
tunnels util.SSHTunnelList
tunnels *util.SSHTunnelList
tunnelsLock sync.Mutex
installSSHKey InstallSSHKey
}
@ -772,7 +772,7 @@ func (m *Master) Dial(net, addr string) (net.Conn, error) {
}
func (m *Master) needToReplaceTunnels(addrs []string) bool {
if len(m.tunnels) != len(addrs) {
if m.tunnels == nil || m.tunnels.Len() != len(addrs) {
return true
}
// TODO (cjcullen): This doesn't need to be n^2
@ -850,7 +850,7 @@ func (m *Master) setupSecureProxy(user, keyfile string) {
if err := m.loadTunnels(user, keyfile); err != nil {
glog.Errorf("Failed to load SSH Tunnels: %v", err)
}
if len(m.tunnels) != 0 {
if m.tunnels != nil && m.tunnels.Len() != 0 {
// Sleep for 10 seconds if we have some tunnels.
// TODO (cjcullen): tunnels can lag behind actually existing nodes.
time.Sleep(9 * time.Second)

View File

@ -207,9 +207,11 @@ type SSHTunnelEntry struct {
Tunnel *SSHTunnel
}
type SSHTunnelList []SSHTunnelEntry
type SSHTunnelList struct {
entries []SSHTunnelEntry
}
func MakeSSHTunnels(user, keyfile string, addresses []string) (SSHTunnelList, error) {
func MakeSSHTunnels(user, keyfile string, addresses []string) (*SSHTunnelList, error) {
tunnels := []SSHTunnelEntry{}
for ix := range addresses {
addr := addresses[ix]
@ -219,18 +221,22 @@ func MakeSSHTunnels(user, keyfile string, addresses []string) (SSHTunnelList, er
}
tunnels = append(tunnels, SSHTunnelEntry{addr, tunnel})
}
return tunnels, nil
return &SSHTunnelList{tunnels}, nil
}
func (l SSHTunnelList) Open() error {
for ix := range l {
if err := l[ix].Tunnel.Open(); err != nil {
// Remove a failed Open from the list.
glog.Errorf("Failed to open tunnel %v: %v", l[ix], err)
l = append(l[:ix], l[ix+1:]...)
// Open attempts to open all tunnels in the list, and removes any tunnels that
// failed to open.
func (l *SSHTunnelList) Open() error {
var openTunnels []SSHTunnelEntry
for ix := range l.entries {
if err := l.entries[ix].Tunnel.Open(); err != nil {
glog.Errorf("Failed to open tunnel %v: %v", l.entries[ix], err)
} else {
openTunnels = append(openTunnels, l.entries[ix])
}
}
if len(l) == 0 {
l.entries = openTunnels
if len(l.entries) == 0 {
return errors.New("Failed to open any tunnels.")
}
return nil
@ -239,9 +245,9 @@ func (l SSHTunnelList) Open() error {
// Close asynchronously closes all tunnels in the list after waiting for 1
// minute. Tunnels will still be open upon this function's return, but should
// no longer be used.
func (l SSHTunnelList) Close() {
for ix := range l {
entry := l[ix]
func (l *SSHTunnelList) Close() {
for ix := range l.entries {
entry := l.entries[ix]
go func() {
defer HandleCrash()
time.Sleep(1 * time.Minute)
@ -252,22 +258,26 @@ func (l SSHTunnelList) Close() {
}
}
func (l SSHTunnelList) Dial(network, addr string) (net.Conn, error) {
if len(l) == 0 {
func (l *SSHTunnelList) Dial(network, addr string) (net.Conn, error) {
if len(l.entries) == 0 {
return nil, fmt.Errorf("Empty tunnel list.")
}
return l[mathrand.Int()%len(l)].Tunnel.Dial(network, addr)
return l.entries[mathrand.Int()%len(l.entries)].Tunnel.Dial(network, addr)
}
func (l SSHTunnelList) Has(addr string) bool {
for ix := range l {
if l[ix].Address == addr {
func (l *SSHTunnelList) Has(addr string) bool {
for ix := range l.entries {
if l.entries[ix].Address == addr {
return true
}
}
return false
}
func (l *SSHTunnelList) Len() int {
return len(l.entries)
}
func EncodePrivateKey(private *rsa.PrivateKey) []byte {
return pem.EncodeToMemory(&pem.Block{
Bytes: x509.MarshalPKCS1PrivateKey(private),