diff --git a/pkg/proxy/proxier.go b/pkg/proxy/proxier.go index 1e835db7e8f..5f30935b120 100644 --- a/pkg/proxy/proxier.go +++ b/pkg/proxy/proxier.go @@ -39,16 +39,22 @@ type serviceInfo struct { active bool } +// How long we wait for a connection to a backend. +const endpointDialTimeout = 5 * time.Second + // Abstraction over TCP/UDP sockets which are proxied. type proxySocket interface { // Addr gets the net.Addr for a proxySocket. Addr() net.Addr - // Close stops the proxySocket from accepting incoming connections. + // Close stops the proxySocket from accepting incoming connections. Each implementation should comment + // on the impact of calling Close while sessions are active. Close() error // ProxyLoop proxies incoming connections for the specified service to the service endpoints. ProxyLoop(service string, proxier *Proxier) } +// tcpProxySocket implements proxySocket. Close() is implemented by net.Listener. When Close() is called, +// no new connections are allowed but existing connections are left untouched. type tcpProxySocket struct { net.Listener } @@ -73,7 +79,7 @@ func (tcp *tcpProxySocket) ProxyLoop(service string, proxier *Proxier) { glog.Errorf("Accept failed: %v", err) continue } - glog.Infof("Accepted connection from %v to %v", inConn.RemoteAddr(), inConn.LocalAddr()) + glog.Infof("Accepted TCP connection from %v to %v", inConn.RemoteAddr(), inConn.LocalAddr()) endpoint, err := proxier.loadBalancer.NextEndpoint(service, inConn.RemoteAddr()) if err != nil { glog.Errorf("Couldn't find an endpoint for %s %v", service, err) @@ -83,7 +89,7 @@ func (tcp *tcpProxySocket) ProxyLoop(service string, proxier *Proxier) { glog.Infof("Mapped service %s to endpoint %s", service, endpoint) // TODO: This could spin up a new goroutine to make the outbound connection, // and keep accepting inbound traffic. - outConn, err := net.DialTimeout("tcp", endpoint, time.Duration(5)*time.Second) + outConn, err := net.DialTimeout("tcp", endpoint, endpointDialTimeout) if err != nil { // TODO: Try another endpoint? glog.Errorf("Dial failed: %v", err) @@ -103,15 +109,160 @@ func proxyTCP(in, out *net.TCPConn) { go copyBytes(out, in) } -func newProxySocket(protocol string, addr string, port int) (proxySocket, error) { +// udpProxySocket implements proxySocket. Close() is implemented by net.UDPConn. When Close() is called, +// no new connections are allowed and existing connections are broken. +// TODO: We could lame-duck this ourselves, if it becomes important. +type udpProxySocket struct { + *net.UDPConn +} + +func (udp *udpProxySocket) Addr() net.Addr { + return udp.LocalAddr() +} + +// Holds all the known UDP clients that have not timed out. +type clientCache struct { + mu sync.Mutex + clients map[string]net.Conn // addr string -> connection +} + +func newClientCache() *clientCache { + return &clientCache{clients: map[string]net.Conn{}} +} + +// How long we leave idle UDP connections open. +const udpIdleTimeout = 1 * time.Minute + +func (udp *udpProxySocket) ProxyLoop(service string, proxier *Proxier) { + info, found := proxier.getServiceInfo(service) + if !found { + glog.Errorf("Failed to find service: %s", service) + return + } + activeClients := newClientCache() + var buffer [4096]byte // 4KiB should be enough for most whole-packets + for { + info.mu.Lock() + if !info.active { + info.mu.Unlock() + break + } + info.mu.Unlock() + + // Block until data arrives. + // TODO: Accumulate a histogram of n or something, to fine tune the buffer size. + n, cliAddr, err := udp.ReadFrom(buffer[0:]) + if err != nil { + if e, ok := err.(net.Error); ok { + if e.Temporary() { + glog.Infof("ReadFrom had a temporary failure: %v", err) + continue + } + } + glog.Errorf("ReadFrom failed, exiting ProxyLoop: %v", err) + break + } + // If this is a client we know already, reuse the connection and goroutine. + activeClients.mu.Lock() + svrConn, found := activeClients.clients[cliAddr.String()] + if !found { + // TODO: This could spin up a new goroutine to make the outbound connection, + // and keep accepting inbound traffic. + glog.Infof("New UDP connection from %s", cliAddr) + endpoint, err := proxier.loadBalancer.NextEndpoint(service, cliAddr) + if err != nil { + glog.Errorf("Couldn't find an endpoint for %s %v", service, err) + activeClients.mu.Unlock() + continue + } + glog.Infof("Mapped service %s to endpoint %s", service, endpoint) + svrConn, err = net.DialTimeout("udp", endpoint, endpointDialTimeout) + if err != nil { + // TODO: Try another endpoint? + glog.Errorf("Dial failed: %v", err) + activeClients.mu.Unlock() + continue + } + activeClients.clients[cliAddr.String()] = svrConn + go udp.proxyClient(cliAddr, svrConn, activeClients) + } + activeClients.mu.Unlock() + // TODO: It would be nice to let the goroutine handle this write, but we don't + // really want to copy the buffer. We could do a pool of buffers or something. + _, err = svrConn.Write(buffer[0:n]) + if err != nil { + if !logTimeout(err) { + glog.Errorf("Write failed: %v", err) + // TODO: Maybe tear down the goroutine for this client/server pair? + } + continue + } + svrConn.SetDeadline(time.Now().Add(udpIdleTimeout)) + if err != nil { + glog.Errorf("SetDeadline failed: %v", err) + continue + } + } +} + +// This function is expected to be called as a goroutine. +func (udp *udpProxySocket) proxyClient(cliAddr net.Addr, svrConn net.Conn, activeClients *clientCache) { + defer svrConn.Close() + var buffer [4096]byte + for { + n, err := svrConn.Read(buffer[0:]) + if err != nil { + if !logTimeout(err) { + glog.Errorf("Read failed: %v", err) + } + break + } + svrConn.SetDeadline(time.Now().Add(udpIdleTimeout)) + if err != nil { + glog.Errorf("SetDeadline failed: %v", err) + break + } + n, err = udp.WriteTo(buffer[0:n], cliAddr) + if err != nil { + if !logTimeout(err) { + glog.Errorf("WriteTo failed: %v", err) + } + break + } + } + activeClients.mu.Lock() + delete(activeClients.clients, cliAddr.String()) + activeClients.mu.Unlock() +} + +func logTimeout(err error) bool { + if e, ok := err.(net.Error); ok { + if e.Timeout() { + glog.Infof("connection to endpoint closed due to inactivity") + return true + } + } + return false +} + +func newProxySocket(protocol string, host string, port int) (proxySocket, error) { switch strings.ToUpper(protocol) { case "TCP": - listener, err := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(port))) + listener, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) if err != nil { return nil, err } return &tcpProxySocket{listener}, nil - //TODO: add UDP support + case "UDP": + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(host, strconv.Itoa(port))) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + return &udpProxySocket{conn}, nil } return nil, fmt.Errorf("Unknown protocol %q", protocol) } @@ -162,7 +313,6 @@ func (proxier *Proxier) stopProxyInternal(info *serviceInfo) error { return nil } glog.Infof("Removing service: %s", info.name) - info.active = false return info.socket.Close() } @@ -211,7 +361,7 @@ func (proxier *Proxier) addServiceOnUnusedPort(service, protocol string) (string } func (proxier *Proxier) startAccepting(service string, sock proxySocket) { - glog.Infof("Listening for %s on %s", service, sock.Addr().String()) + glog.Infof("Listening for %s on %s:%s", service, sock.Addr().Network(), sock.Addr().String()) go sock.ProxyLoop(service, proxier) } @@ -224,11 +374,15 @@ func (proxier *Proxier) OnUpdate(services []api.Service) { for _, service := range services { activeServices.Insert(service.ID) info, exists := proxier.getServiceInfo(service.ID) + // TODO: check health of the socket? What if ProxyLoop exited? if exists && info.active && info.port == service.Port { continue } if exists && info.port != service.Port { - proxier.StopProxy(service.ID) + err := proxier.stopProxyInternal(info) + if err != nil { + glog.Errorf("error stopping %s: %v", info.name, err) + } } glog.Infof("Adding a new service %s on %s port %d", service.ID, service.Protocol, service.Port) sock, err := newProxySocket(service.Protocol, proxier.address, service.Port) @@ -248,7 +402,10 @@ func (proxier *Proxier) OnUpdate(services []api.Service) { defer proxier.mu.Unlock() for name, info := range proxier.serviceMap { if !activeServices.Has(name) { - proxier.stopProxyInternal(info) + err := proxier.stopProxyInternal(info) + if err != nil { + glog.Errorf("error stopping %s: %v", info.name, err) + } } } } diff --git a/pkg/proxy/proxier_test.go b/pkg/proxy/proxier_test.go index bb58aaa6f15..03c7550bb78 100644 --- a/pkg/proxy/proxier_test.go +++ b/pkg/proxy/proxier_test.go @@ -32,23 +32,53 @@ import ( func waitForClosedPortTCP(p *Proxier, proxyPort string) error { for i := 0; i < 50; i++ { - _, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", proxyPort)) + conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", proxyPort)) if err != nil { return nil } + conn.Close() + time.Sleep(1 * time.Millisecond) + } + return fmt.Errorf("port %s still open", proxyPort) +} + +func waitForClosedPortUDP(p *Proxier, proxyPort string) error { + for i := 0; i < 50; i++ { + conn, err := net.Dial("udp", net.JoinHostPort("127.0.0.1", proxyPort)) + if err != nil { + return nil + } + conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + // To detect a closed UDP port write, then read. + _, err = conn.Write([]byte("x")) + if err != nil { + if e, ok := err.(net.Error); ok && !e.Timeout() { + return nil + } + } + var buf [4]byte + _, err = conn.Read(buf[0:]) + if err != nil { + if e, ok := err.(net.Error); ok && !e.Timeout() { + return nil + } + } + conn.Close() time.Sleep(1 * time.Millisecond) } return fmt.Errorf("port %s still open", proxyPort) } var tcpServerPort string +var udpServerPort string func init() { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // TCP setup. + tcp := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(r.URL.Path[1:])) })) - u, err := url.Parse(ts.URL) + u, err := url.Parse(tcp.URL) if err != nil { panic(fmt.Sprintf("failed to parse: %v", err)) } @@ -56,6 +86,17 @@ func init() { if err != nil { panic(fmt.Sprintf("failed to parse: %v", err)) } + + // UDP setup. + udp, err := newUDPEchoServer() + if err != nil { + panic(fmt.Sprintf("failed to make a UDP server: %v", err)) + } + _, udpServerPort, err = net.SplitHostPort(udp.LocalAddr().String()) + if err != nil { + panic(fmt.Sprintf("failed to parse: %v", err)) + } + go udp.Loop() } func testEchoTCP(t *testing.T, address, port string) { @@ -74,6 +115,26 @@ func testEchoTCP(t *testing.T, address, port string) { } } +func testEchoUDP(t *testing.T, address, port string) { + data := "abc123" + + conn, err := net.Dial("udp", net.JoinHostPort(address, port)) + if err != nil { + t.Fatalf("error connecting to server: %v", err) + } + if _, err := conn.Write([]byte(data)); err != nil { + t.Fatalf("error sending to server: %v", err) + } + var resp [1024]byte + n, err := conn.Read(resp[0:]) + if err != nil { + t.Errorf("error receiving data: %v", err) + } + if string(resp[0:n]) != data { + t.Errorf("expected: %s, got %s", data, string(resp[0:n])) + } +} + func TestTCPProxy(t *testing.T) { lb := NewLoadBalancerRR() lb.OnUpdate([]api.Endpoints{ @@ -92,6 +153,24 @@ func TestTCPProxy(t *testing.T) { testEchoTCP(t, "127.0.0.1", proxyPort) } +func TestUDPProxy(t *testing.T) { + lb := NewLoadBalancerRR() + lb.OnUpdate([]api.Endpoints{ + { + JSONBase: api.JSONBase{ID: "echo"}, + Endpoints: []string{net.JoinHostPort("127.0.0.1", udpServerPort)}, + }, + }) + + p := NewProxier(lb, "127.0.0.1") + + proxyPort, err := p.addServiceOnUnusedPort("echo", "UDP") + if err != nil { + t.Fatalf("error adding new service: %#v", err) + } + testEchoUDP(t, "127.0.0.1", proxyPort) +} + func TestTCPProxyStop(t *testing.T) { lb := NewLoadBalancerRR() lb.OnUpdate([]api.Endpoints{ @@ -120,6 +199,34 @@ func TestTCPProxyStop(t *testing.T) { } } +func TestUDPProxyStop(t *testing.T) { + lb := NewLoadBalancerRR() + lb.OnUpdate([]api.Endpoints{ + { + JSONBase: api.JSONBase{ID: "echo"}, + Endpoints: []string{net.JoinHostPort("127.0.0.1", udpServerPort)}, + }, + }) + + p := NewProxier(lb, "127.0.0.1") + + proxyPort, err := p.addServiceOnUnusedPort("echo", "UDP") + if err != nil { + t.Fatalf("error adding new service: %#v", err) + } + conn, err := net.Dial("udp", net.JoinHostPort("127.0.0.1", proxyPort)) + if err != nil { + t.Fatalf("error connecting to proxy: %v", err) + } + conn.Close() + + p.StopProxy("echo") + // Wait for the port to really close. + if err := waitForClosedPortUDP(p, proxyPort); err != nil { + t.Fatalf(err.Error()) + } +} + func TestTCPProxyUpdateDelete(t *testing.T) { lb := NewLoadBalancerRR() lb.OnUpdate([]api.Endpoints{ @@ -147,6 +254,33 @@ func TestTCPProxyUpdateDelete(t *testing.T) { } } +func TestUDPProxyUpdateDelete(t *testing.T) { + lb := NewLoadBalancerRR() + lb.OnUpdate([]api.Endpoints{ + { + JSONBase: api.JSONBase{ID: "echo"}, + Endpoints: []string{net.JoinHostPort("127.0.0.1", udpServerPort)}, + }, + }) + + p := NewProxier(lb, "127.0.0.1") + + proxyPort, err := p.addServiceOnUnusedPort("echo", "UDP") + if err != nil { + t.Fatalf("error adding new service: %#v", err) + } + conn, err := net.Dial("udp", net.JoinHostPort("127.0.0.1", proxyPort)) + if err != nil { + t.Fatalf("error connecting to proxy: %v", err) + } + conn.Close() + + p.OnUpdate([]api.Service{}) + if err := waitForClosedPortUDP(p, proxyPort); err != nil { + t.Fatalf(err.Error()) + } +} + func TestTCPProxyUpdateDeleteUpdate(t *testing.T) { lb := NewLoadBalancerRR() lb.OnUpdate([]api.Endpoints{ @@ -179,6 +313,38 @@ func TestTCPProxyUpdateDeleteUpdate(t *testing.T) { testEchoTCP(t, "127.0.0.1", proxyPort) } +func TestUDPProxyUpdateDeleteUpdate(t *testing.T) { + lb := NewLoadBalancerRR() + lb.OnUpdate([]api.Endpoints{ + { + JSONBase: api.JSONBase{ID: "echo"}, + Endpoints: []string{net.JoinHostPort("127.0.0.1", udpServerPort)}, + }, + }) + + p := NewProxier(lb, "127.0.0.1") + + proxyPort, err := p.addServiceOnUnusedPort("echo", "UDP") + if err != nil { + t.Fatalf("error adding new service: %#v", err) + } + conn, err := net.Dial("udp", net.JoinHostPort("127.0.0.1", proxyPort)) + if err != nil { + t.Fatalf("error connecting to proxy: %v", err) + } + conn.Close() + + p.OnUpdate([]api.Service{}) + if err := waitForClosedPortUDP(p, proxyPort); err != nil { + t.Fatalf(err.Error()) + } + proxyPortNum, _ := strconv.Atoi(proxyPort) + p.OnUpdate([]api.Service{ + {JSONBase: api.JSONBase{ID: "echo"}, Port: proxyPortNum, Protocol: "UDP"}, + }) + testEchoUDP(t, "127.0.0.1", proxyPort) +} + func TestTCPProxyUpdatePort(t *testing.T) { lb := NewLoadBalancerRR() lb.OnUpdate([]api.Endpoints{ @@ -223,3 +389,48 @@ func TestTCPProxyUpdatePort(t *testing.T) { } l.Close() } + +func TestUDPProxyUpdatePort(t *testing.T) { + lb := NewLoadBalancerRR() + lb.OnUpdate([]api.Endpoints{ + { + JSONBase: api.JSONBase{ID: "echo"}, + Endpoints: []string{net.JoinHostPort("127.0.0.1", udpServerPort)}, + }, + }) + + p := NewProxier(lb, "127.0.0.1") + + proxyPort, err := p.addServiceOnUnusedPort("echo", "UDP") + if err != nil { + t.Fatalf("error adding new service: %#v", err) + } + + // add a new dummy listener in order to get a port that is free + pc, _ := net.ListenPacket("udp", ":0") + _, newPort, _ := net.SplitHostPort(pc.LocalAddr().String()) + newPortNum, _ := strconv.Atoi(newPort) + pc.Close() + + // Wait for the socket to actually get free. + if err := waitForClosedPortUDP(p, newPort); err != nil { + t.Fatalf(err.Error()) + } + if proxyPort == newPort { + t.Errorf("expected difference, got %s %s", newPort, proxyPort) + } + p.OnUpdate([]api.Service{ + {JSONBase: api.JSONBase{ID: "echo"}, Port: newPortNum, Protocol: "UDP"}, + }) + if err := waitForClosedPortUDP(p, proxyPort); err != nil { + t.Fatalf(err.Error()) + } + testEchoUDP(t, "127.0.0.1", newPort) + + // Ensure the old port is released and re-usable. + pc, err = net.ListenPacket("udp", net.JoinHostPort("", proxyPort)) + if err != nil { + t.Fatalf("can't claim released port: %s", err) + } + pc.Close() +} diff --git a/pkg/proxy/udp_server.go b/pkg/proxy/udp_server.go new file mode 100644 index 00000000000..c5489a49f20 --- /dev/null +++ b/pkg/proxy/udp_server.go @@ -0,0 +1,54 @@ +/* +Copyright 2014 Google Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "fmt" + "net" +) + +// udpEchoServer is a simple echo server in UDP, intended for testing the proxy. +type udpEchoServer struct { + net.PacketConn +} + +func (r *udpEchoServer) Loop() { + var buffer [4096]byte + for { + n, cliAddr, err := r.ReadFrom(buffer[0:]) + if err != nil { + fmt.Printf("ReadFrom failed: %#v\n", err) + continue + } + r.WriteTo(buffer[0:n], cliAddr) + } +} + +func newUDPEchoServer() (*udpEchoServer, error) { + packetconn, err := net.ListenPacket("udp", ":0") + if err != nil { + return nil, err + } + return &udpEchoServer{packetconn}, nil +} + +/* +func main() { + r,_ := newUDPEchoServer() + r.Loop() +} +*/