Refactor proxy code to make room for UDP

This commit is contained in:
Tim Hockin 2014-09-10 13:44:20 -07:00
parent 9f275c81ac
commit cad6122fe4
2 changed files with 153 additions and 128 deletions

View File

@ -21,6 +21,7 @@ import (
"io" "io"
"net" "net"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@ -32,11 +33,89 @@ import (
type serviceInfo struct { type serviceInfo struct {
name string name string
port int port int
listener net.Listener protocol string
socket proxySocket
mu sync.Mutex // protects active mu sync.Mutex // protects active
active bool active bool
} }
// Abstraction over TCP/UDP sockets which are proxied.
type proxySocket interface {
// Addr gets the net.Addr for a proxySocket.
Addr() net.Addr
// Close stops the proxySocket from accepting incoming connections.
Close() error
// ProxyLoop proxies incoming connections for the specified service to the service endpoints.
ProxyLoop(service string, proxier *Proxier)
}
type tcpProxySocket struct {
net.Listener
}
func (tcp *tcpProxySocket) ProxyLoop(service string, proxier *Proxier) {
info, found := proxier.getServiceInfo(service)
if !found {
glog.Errorf("Failed to find service: %s", service)
return
}
for {
info.mu.Lock()
if !info.active {
info.mu.Unlock()
break
}
info.mu.Unlock()
// Block until a connection is made.
inConn, err := tcp.Accept()
if err != nil {
glog.Errorf("Accept failed: %v", err)
continue
}
glog.Infof("Accepted connection from %v to %v", inConn.RemoteAddr(), inConn.LocalAddr())
endpoint, err := proxier.loadBalancer.NextEndpoint(service, inConn.RemoteAddr())
if err != nil {
glog.Errorf("Couldn't find an endpoint for %s %v", service, err)
inConn.Close()
continue
}
glog.Infof("Mapped service %s to endpoint %s", service, endpoint)
// TODO: This could spin up a new goroutine to make the outbound connection,
// and keep accepting inbound traffic.
outConn, err := net.DialTimeout("tcp", endpoint, time.Duration(5)*time.Second)
if err != nil {
// TODO: Try another endpoint?
glog.Errorf("Dial failed: %v", err)
inConn.Close()
continue
}
// Spin up an async copy loop.
proxyTCP(inConn.(*net.TCPConn), outConn.(*net.TCPConn))
}
}
// proxyTCP proxies data bi-directionally between in and out.
func proxyTCP(in, out *net.TCPConn) {
glog.Infof("Creating proxy between %v <-> %v <-> %v <-> %v",
in.RemoteAddr(), in.LocalAddr(), out.LocalAddr(), out.RemoteAddr())
go copyBytes(in, out)
go copyBytes(out, in)
}
func newProxySocket(protocol string, addr string, port int) (proxySocket, error) {
switch strings.ToUpper(protocol) {
case "TCP":
listener, err := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(port)))
if err != nil {
return nil, err
}
return &tcpProxySocket{listener}, nil
//TODO: add UDP support
}
return nil, fmt.Errorf("Unknown protocol %q", protocol)
}
// Proxier is a simple proxy for TCP connections between a localhost:lport // Proxier is a simple proxy for TCP connections between a localhost:lport
// and services that provide the actual implementations. // and services that provide the actual implementations.
type Proxier struct { type Proxier struct {
@ -66,14 +145,6 @@ func copyBytes(in, out *net.TCPConn) {
out.CloseWrite() out.CloseWrite()
} }
// proxyConnection proxies data bidirectionally between in and out.
func proxyConnection(in, out *net.TCPConn) {
glog.Infof("Creating proxy between %v <-> %v <-> %v <-> %v",
in.RemoteAddr(), in.LocalAddr(), out.LocalAddr(), out.RemoteAddr())
go copyBytes(in, out)
go copyBytes(out, in)
}
// StopProxy stops the proxy for the named service. // StopProxy stops the proxy for the named service.
func (proxier *Proxier) StopProxy(service string) error { func (proxier *Proxier) StopProxy(service string) error {
// TODO: delete from map here? // TODO: delete from map here?
@ -92,7 +163,7 @@ func (proxier *Proxier) stopProxyInternal(info *serviceInfo) error {
} }
glog.Infof("Removing service: %s", info.name) glog.Infof("Removing service: %s", info.name)
info.active = false info.active = false
return info.listener.Close() return info.socket.Close()
} }
func (proxier *Proxier) getServiceInfo(service string) (*serviceInfo, bool) { func (proxier *Proxier) getServiceInfo(service string) (*serviceInfo, bool) {
@ -109,57 +180,19 @@ func (proxier *Proxier) setServiceInfo(service string, info *serviceInfo) {
proxier.serviceMap[service] = info proxier.serviceMap[service] = info
} }
// AcceptHandler proxies incoming connections for the specified service
// to the load-balanced service endpoints.
func (proxier *Proxier) AcceptHandler(service string, listener net.Listener) {
info, found := proxier.getServiceInfo(service)
if !found {
glog.Errorf("Failed to find service: %s", service)
return
}
for {
info.mu.Lock()
if !info.active {
info.mu.Unlock()
break
}
info.mu.Unlock()
inConn, err := listener.Accept()
if err != nil {
glog.Errorf("Accept failed: %v", err)
continue
}
glog.Infof("Accepted connection from: %v to %v", inConn.RemoteAddr(), inConn.LocalAddr())
endpoint, err := proxier.loadBalancer.NextEndpoint(service, inConn.RemoteAddr())
if err != nil {
glog.Errorf("Couldn't find an endpoint for %s %v", service, err)
inConn.Close()
continue
}
glog.Infof("Mapped service %s to endpoint %s", service, endpoint)
outConn, err := net.DialTimeout("tcp", endpoint, time.Duration(5)*time.Second)
if err != nil {
glog.Errorf("Dial failed: %v", err)
inConn.Close()
continue
}
proxyConnection(inConn.(*net.TCPConn), outConn.(*net.TCPConn))
}
}
// used to globally lock around unused ports. Only used in testing. // used to globally lock around unused ports. Only used in testing.
var unusedPortLock sync.Mutex var unusedPortLock sync.Mutex
// addServiceOnUnusedPort starts listening for a new service, returning the // addServiceOnUnusedPort starts listening for a new service, returning the
// port it's using. For testing on a system with unknown ports used. // port it's using. For testing on a system with unknown ports used.
func (proxier *Proxier) addServiceOnUnusedPort(service string) (string, error) { func (proxier *Proxier) addServiceOnUnusedPort(service, protocol string) (string, error) {
unusedPortLock.Lock() unusedPortLock.Lock()
defer unusedPortLock.Unlock() defer unusedPortLock.Unlock()
l, err := net.Listen("tcp", net.JoinHostPort(proxier.address, "0")) sock, err := newProxySocket(protocol, proxier.address, 0)
if err != nil { if err != nil {
return "", err return "", err
} }
_, port, err := net.SplitHostPort(l.Addr().String()) _, port, err := net.SplitHostPort(sock.Addr().String())
if err != nil { if err != nil {
return "", err return "", err
} }
@ -169,16 +202,17 @@ func (proxier *Proxier) addServiceOnUnusedPort(service string) (string, error) {
} }
proxier.setServiceInfo(service, &serviceInfo{ proxier.setServiceInfo(service, &serviceInfo{
port: portNum, port: portNum,
protocol: protocol,
active: true, active: true,
listener: l, socket: sock,
}) })
proxier.startAccepting(service, l) proxier.startAccepting(service, sock)
return port, nil return port, nil
} }
func (proxier *Proxier) startAccepting(service string, l net.Listener) { func (proxier *Proxier) startAccepting(service string, sock proxySocket) {
glog.Infof("Listening for %s on %s", service, l.Addr().String()) glog.Infof("Listening for %s on %s", service, sock.Addr().String())
go proxier.AcceptHandler(service, l) go sock.ProxyLoop(service, proxier)
} }
// OnUpdate manages the active set of service proxies. // OnUpdate manages the active set of service proxies.
@ -196,18 +230,19 @@ func (proxier *Proxier) OnUpdate(services []api.Service) {
if exists && info.port != service.Port { if exists && info.port != service.Port {
proxier.StopProxy(service.ID) proxier.StopProxy(service.ID)
} }
glog.Infof("Adding a new service %s on port %d", service.ID, service.Port) glog.Infof("Adding a new service %s on %s port %d", service.ID, service.Protocol, service.Port)
listener, err := net.Listen("tcp", net.JoinHostPort(proxier.address, strconv.Itoa(service.Port))) sock, err := newProxySocket(service.Protocol, proxier.address, service.Port)
if err != nil { if err != nil {
glog.Infof("Failed to start listening for %s on %d", service.ID, service.Port) glog.Errorf("Failed to get a socket for %s: %+v", service.ID, err)
continue continue
} }
proxier.setServiceInfo(service.ID, &serviceInfo{ proxier.setServiceInfo(service.ID, &serviceInfo{
port: service.Port, port: service.Port,
protocol: service.Protocol,
active: true, active: true,
listener: listener, socket: sock,
}) })
proxier.startAccepting(service.ID, listener) proxier.startAccepting(service.ID, sock)
} }
proxier.mu.Lock() proxier.mu.Lock()
defer proxier.mu.Unlock() defer proxier.mu.Unlock()

View File

@ -30,7 +30,7 @@ import (
"github.com/GoogleCloudPlatform/kubernetes/pkg/api" "github.com/GoogleCloudPlatform/kubernetes/pkg/api"
) )
func waitForClosedPort(p *Proxier, proxyPort string) error { func waitForClosedPortTCP(p *Proxier, proxyPort string) error {
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
_, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", proxyPort)) _, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", proxyPort))
if err != nil { if err != nil {
@ -41,7 +41,7 @@ func waitForClosedPort(p *Proxier, proxyPort string) error {
return fmt.Errorf("port %s still open", proxyPort) return fmt.Errorf("port %s still open", proxyPort)
} }
var port string var tcpServerPort string
func init() { func init() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -52,13 +52,13 @@ func init() {
if err != nil { if err != nil {
panic(fmt.Sprintf("failed to parse: %v", err)) panic(fmt.Sprintf("failed to parse: %v", err))
} }
_, port, err = net.SplitHostPort(u.Host) _, tcpServerPort, err = net.SplitHostPort(u.Host)
if err != nil { if err != nil {
panic(fmt.Sprintf("failed to parse: %v", err)) panic(fmt.Sprintf("failed to parse: %v", err))
} }
} }
func testEchoConnection(t *testing.T, address, port string) { func testEchoTCP(t *testing.T, address, port string) {
path := "aaaaa" path := "aaaaa"
res, err := http.Get("http://" + address + ":" + port + "/" + path) res, err := http.Get("http://" + address + ":" + port + "/" + path)
if err != nil { if err != nil {
@ -74,27 +74,36 @@ func testEchoConnection(t *testing.T, address, port string) {
} }
} }
func TestProxy(t *testing.T) { func TestTCPProxy(t *testing.T) {
lb := NewLoadBalancerRR() lb := NewLoadBalancerRR()
lb.OnUpdate([]api.Endpoints{ lb.OnUpdate([]api.Endpoints{
{JSONBase: api.JSONBase{ID: "echo"}, Endpoints: []string{net.JoinHostPort("127.0.0.1", port)}}}) {
JSONBase: api.JSONBase{ID: "echo"},
Endpoints: []string{net.JoinHostPort("127.0.0.1", tcpServerPort)},
},
})
p := NewProxier(lb, "127.0.0.1") p := NewProxier(lb, "127.0.0.1")
proxyPort, err := p.addServiceOnUnusedPort("echo") proxyPort, err := p.addServiceOnUnusedPort("echo", "TCP")
if err != nil { if err != nil {
t.Fatalf("error adding new service: %#v", err) t.Fatalf("error adding new service: %#v", err)
} }
testEchoConnection(t, "127.0.0.1", proxyPort) testEchoTCP(t, "127.0.0.1", proxyPort)
} }
func TestProxyStop(t *testing.T) { func TestTCPProxyStop(t *testing.T) {
lb := NewLoadBalancerRR() lb := NewLoadBalancerRR()
lb.OnUpdate([]api.Endpoints{{JSONBase: api.JSONBase{ID: "echo"}, Endpoints: []string{net.JoinHostPort("127.0.0.1", port)}}}) lb.OnUpdate([]api.Endpoints{
{
JSONBase: api.JSONBase{ID: "echo"},
Endpoints: []string{net.JoinHostPort("127.0.0.1", tcpServerPort)},
},
})
p := NewProxier(lb, "127.0.0.1") p := NewProxier(lb, "127.0.0.1")
proxyPort, err := p.addServiceOnUnusedPort("echo") proxyPort, err := p.addServiceOnUnusedPort("echo", "TCP")
if err != nil { if err != nil {
t.Fatalf("error adding new service: %#v", err) t.Fatalf("error adding new service: %#v", err)
} }
@ -106,18 +115,23 @@ func TestProxyStop(t *testing.T) {
p.StopProxy("echo") p.StopProxy("echo")
// Wait for the port to really close. // Wait for the port to really close.
if err := waitForClosedPort(p, proxyPort); err != nil { if err := waitForClosedPortTCP(p, proxyPort); err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
} }
func TestProxyUpdateDelete(t *testing.T) { func TestTCPProxyUpdateDelete(t *testing.T) {
lb := NewLoadBalancerRR() lb := NewLoadBalancerRR()
lb.OnUpdate([]api.Endpoints{{JSONBase: api.JSONBase{ID: "echo"}, Endpoints: []string{net.JoinHostPort("127.0.0.1", port)}}}) lb.OnUpdate([]api.Endpoints{
{
JSONBase: api.JSONBase{ID: "echo"},
Endpoints: []string{net.JoinHostPort("127.0.0.1", tcpServerPort)},
},
})
p := NewProxier(lb, "127.0.0.1") p := NewProxier(lb, "127.0.0.1")
proxyPort, err := p.addServiceOnUnusedPort("echo") proxyPort, err := p.addServiceOnUnusedPort("echo", "TCP")
if err != nil { if err != nil {
t.Fatalf("error adding new service: %#v", err) t.Fatalf("error adding new service: %#v", err)
} }
@ -128,18 +142,23 @@ func TestProxyUpdateDelete(t *testing.T) {
conn.Close() conn.Close()
p.OnUpdate([]api.Service{}) p.OnUpdate([]api.Service{})
if err := waitForClosedPort(p, proxyPort); err != nil { if err := waitForClosedPortTCP(p, proxyPort); err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
} }
func TestProxyUpdateDeleteUpdate(t *testing.T) { func TestTCPProxyUpdateDeleteUpdate(t *testing.T) {
lb := NewLoadBalancerRR() lb := NewLoadBalancerRR()
lb.OnUpdate([]api.Endpoints{{JSONBase: api.JSONBase{ID: "echo"}, Endpoints: []string{net.JoinHostPort("127.0.0.1", port)}}}) lb.OnUpdate([]api.Endpoints{
{
JSONBase: api.JSONBase{ID: "echo"},
Endpoints: []string{net.JoinHostPort("127.0.0.1", tcpServerPort)},
},
})
p := NewProxier(lb, "127.0.0.1") p := NewProxier(lb, "127.0.0.1")
proxyPort, err := p.addServiceOnUnusedPort("echo") proxyPort, err := p.addServiceOnUnusedPort("echo", "TCP")
if err != nil { if err != nil {
t.Fatalf("error adding new service: %#v", err) t.Fatalf("error adding new service: %#v", err)
} }
@ -150,23 +169,28 @@ func TestProxyUpdateDeleteUpdate(t *testing.T) {
conn.Close() conn.Close()
p.OnUpdate([]api.Service{}) p.OnUpdate([]api.Service{})
if err := waitForClosedPort(p, proxyPort); err != nil { if err := waitForClosedPortTCP(p, proxyPort); err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
proxyPortNum, _ := strconv.Atoi(proxyPort) proxyPortNum, _ := strconv.Atoi(proxyPort)
p.OnUpdate([]api.Service{ p.OnUpdate([]api.Service{
{JSONBase: api.JSONBase{ID: "echo"}, Port: proxyPortNum}, {JSONBase: api.JSONBase{ID: "echo"}, Port: proxyPortNum, Protocol: "TCP"},
}) })
testEchoConnection(t, "127.0.0.1", proxyPort) testEchoTCP(t, "127.0.0.1", proxyPort)
} }
func TestProxyUpdatePort(t *testing.T) { func TestTCPProxyUpdatePort(t *testing.T) {
lb := NewLoadBalancerRR() lb := NewLoadBalancerRR()
lb.OnUpdate([]api.Endpoints{{JSONBase: api.JSONBase{ID: "echo"}, Endpoints: []string{net.JoinHostPort("127.0.0.1", port)}}}) lb.OnUpdate([]api.Endpoints{
{
JSONBase: api.JSONBase{ID: "echo"},
Endpoints: []string{net.JoinHostPort("127.0.0.1", tcpServerPort)},
},
})
p := NewProxier(lb, "127.0.0.1") p := NewProxier(lb, "127.0.0.1")
proxyPort, err := p.addServiceOnUnusedPort("echo") proxyPort, err := p.addServiceOnUnusedPort("echo", "TCP")
if err != nil { if err != nil {
t.Fatalf("error adding new service: %#v", err) t.Fatalf("error adding new service: %#v", err)
} }
@ -174,62 +198,28 @@ func TestProxyUpdatePort(t *testing.T) {
// add a new dummy listener in order to get a port that is free // add a new dummy listener in order to get a port that is free
l, _ := net.Listen("tcp", ":0") l, _ := net.Listen("tcp", ":0")
_, newPort, _ := net.SplitHostPort(l.Addr().String()) _, newPort, _ := net.SplitHostPort(l.Addr().String())
portNum, _ := strconv.Atoi(newPort) newPortNum, _ := strconv.Atoi(newPort)
l.Close() l.Close()
// Wait for the socket to actually get free. // Wait for the socket to actually get free.
if err := waitForClosedPort(p, newPort); err != nil { if err := waitForClosedPortTCP(p, newPort); err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
if proxyPort == newPort { if proxyPort == newPort {
t.Errorf("expected difference, got %s %s", newPort, proxyPort) t.Errorf("expected difference, got %s %s", newPort, proxyPort)
} }
p.OnUpdate([]api.Service{ p.OnUpdate([]api.Service{
{JSONBase: api.JSONBase{ID: "echo"}, Port: portNum}, {JSONBase: api.JSONBase{ID: "echo"}, Port: newPortNum, Protocol: "TCP"},
}) })
if err := waitForClosedPort(p, proxyPort); err != nil { if err := waitForClosedPortTCP(p, proxyPort); err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
testEchoConnection(t, "127.0.0.1", newPort) testEchoTCP(t, "127.0.0.1", newPort)
}
func TestProxyUpdatePortLetsGoOfOldPort(t *testing.T) { // Ensure the old port is released and re-usable.
lb := NewLoadBalancerRR() l, err = net.Listen("tcp", net.JoinHostPort("", proxyPort))
lb.OnUpdate([]api.Endpoints{{JSONBase: api.JSONBase{ID: "echo"}, Endpoints: []string{net.JoinHostPort("127.0.0.1", port)}}})
p := NewProxier(lb, "127.0.0.1")
proxyPort, err := p.addServiceOnUnusedPort("echo")
if err != nil { if err != nil {
t.Fatalf("error adding new service: %#v", err) t.Fatalf("can't claim released port: %s", err)
} }
// add a new dummy listener in order to get a port that is free
l, _ := net.Listen("tcp", ":0")
_, newPort, _ := net.SplitHostPort(l.Addr().String())
portNum, _ := strconv.Atoi(newPort)
l.Close() l.Close()
// Wait for the socket to actually get free.
if err := waitForClosedPort(p, newPort); err != nil {
t.Fatalf(err.Error())
}
if proxyPort == newPort {
t.Errorf("expected difference, got %s %s", newPort, proxyPort)
}
p.OnUpdate([]api.Service{
{JSONBase: api.JSONBase{ID: "echo"}, Port: portNum},
})
if err := waitForClosedPort(p, proxyPort); err != nil {
t.Fatalf(err.Error())
}
testEchoConnection(t, "127.0.0.1", newPort)
proxyPortNum, _ := strconv.Atoi(proxyPort)
p.OnUpdate([]api.Service{
{JSONBase: api.JSONBase{ID: "echo"}, Port: proxyPortNum},
})
if err := waitForClosedPort(p, newPort); err != nil {
t.Fatalf(err.Error())
}
testEchoConnection(t, "127.0.0.1", proxyPort)
} }