fix locking around ssh tunnels

This commit is contained in:
Daniel Smith 2015-06-18 16:31:54 -07:00
parent 847d771198
commit 4126622388
2 changed files with 36 additions and 7 deletions

View File

@ -20,6 +20,7 @@ import (
"bytes"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"net/http/pprof"
@ -774,9 +775,23 @@ func findExternalAddress(node *api.Node) (string, error) {
}
func (m *Master) Dial(net, addr string) (net.Conn, error) {
m.tunnelsLock.Lock()
defer m.tunnelsLock.Unlock()
return m.tunnels.Dial(net, addr)
// Only lock while picking a tunnel.
tunnel, err := func() (util.SSHTunnelEntry, error) {
m.tunnelsLock.Lock()
defer m.tunnelsLock.Unlock()
return m.tunnels.PickRandomTunnel()
}()
if err != nil {
return nil, err
}
start := time.Now()
id := rand.Int63() // So you can match begins/ends in the log.
glog.V(3).Infof("[%x: %v] Dialing...", id, tunnel.Address)
defer func() {
glog.V(3).Infof("[%x: %v] Dialed in %v.", id, tunnel.Address, time.Now().Sub(start))
}()
return tunnel.Tunnel.Dial(net, addr)
}
func (m *Master) needToReplaceTunnels(addrs []string) bool {

View File

@ -81,8 +81,8 @@ func NewSSHTunnel(user, keyfile, host string) (*SSHTunnel, error) {
return makeSSHTunnel(user, signer, host)
}
func NewSSHTunnelFromBytes(user string, buffer []byte, host string) (*SSHTunnel, error) {
signer, err := MakePrivateKeySignerFromBytes(buffer)
func NewSSHTunnelFromBytes(user string, privateKey []byte, host string) (*SSHTunnel, error) {
signer, err := MakePrivateKeySignerFromBytes(privateKey)
if err != nil {
return nil, err
}
@ -214,11 +214,13 @@ func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) {
return rsaKey, nil
}
// Should be thread safe.
type SSHTunnelEntry struct {
Address string
Tunnel *SSHTunnel
}
// Not thread safe!
type SSHTunnelList struct {
entries []SSHTunnelEntry
}
@ -271,11 +273,23 @@ func (l *SSHTunnelList) Close() {
}
}
/* this will make sense if we move the lock into SSHTunnelList.
func (l *SSHTunnelList) Dial(network, addr string) (net.Conn, error) {
if len(l.entries) == 0 {
return nil, fmt.Errorf("Empty tunnel list.")
return nil, fmt.Errorf("empty tunnel list.")
}
return l.entries[mathrand.Int()%len(l.entries)].Tunnel.Dial(network, addr)
n := mathrand.Intn(len(l.entries))
return l.entries[n].Tunnel.Dial(network, addr)
}
*/
// Returns a random tunnel, xor an error if there are none.
func (l *SSHTunnelList) PickRandomTunnel() (SSHTunnelEntry, error) {
if len(l.entries) == 0 {
return SSHTunnelEntry{}, fmt.Errorf("empty tunnel list.")
}
n := mathrand.Intn(len(l.entries))
return l.entries[n], nil
}
func (l *SSHTunnelList) Has(addr string) bool {