Merge pull request #109 from djs55/fix-udp-proxy

proxy: add support for UDP
This commit is contained in:
Dave Scott 2016-04-22 18:03:40 +01:00
commit da23023a92
5 changed files with 205 additions and 30 deletions

View File

@ -4,8 +4,8 @@ package libproxy
import (
"fmt"
"net"
"github.com/djs55/vsock"
"net"
)
// Proxy defines the behavior of a proxy. It forwards traffic back and forth
@ -25,21 +25,17 @@ type Proxy interface {
BackendAddr() net.Addr
}
// NewProxy creates a Proxy according to the specified frontendAddr and backendAddr.
func NewProxy(frontendAddr, backendAddr net.Addr) (Proxy, error) {
switch frontendAddr.(type) {
func NewProxy(frontendAddr *vsock.VsockAddr, backendAddr net.Addr) (Proxy, error) {
switch backendAddr.(type) {
case *net.UDPAddr:
return NewUDPProxy(frontendAddr.(*net.UDPAddr), backendAddr.(*net.UDPAddr))
case *net.TCPAddr:
listener, err := net.Listen("tcp", frontendAddr.String())
listener, err := vsock.Listen(frontendAddr.Port)
if err != nil {
return nil, err
}
return NewTCPProxy(listener, backendAddr.(*net.TCPAddr))
case *vsock.VsockAddr:
listener, err := vsock.Listen(frontendAddr.(*vsock.VsockAddr).Port)
return NewUDPProxy(frontendAddr, NewUDPListener(listener), backendAddr.(*net.UDPAddr))
case *net.TCPAddr:
listener, err := vsock.Listen(frontendAddr.Port)
if err != nil {
return nil, err
}

View File

@ -0,0 +1,179 @@
package libproxy
import (
"bytes"
"encoding/binary"
"io"
"net"
"sync"
"github.com/Sirupsen/logrus"
)
type udpListener interface {
ReadFromUDP(b []byte) (int, *net.UDPAddr, error)
WriteToUDP(b []byte, addr *net.UDPAddr) (int, error)
Close() error
}
type udpEncapsulator struct {
conn *net.Conn
listener net.Listener
m *sync.Mutex
r *sync.Mutex
w *sync.Mutex
}
func (u *udpEncapsulator) getConn() (net.Conn, error) {
u.m.Lock()
defer u.m.Unlock()
if u.conn != nil {
return *u.conn, nil
}
conn, err := u.listener.Accept()
if err != nil {
logrus.Printf("Failed to accept connection: %#v", err)
return nil, err
}
u.conn = &conn
return conn, nil
}
func (u *udpEncapsulator) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
conn, err := u.getConn()
if err != nil {
return 0, nil, err
}
u.r.Lock()
defer u.r.Unlock()
datagram := &udpDatagram{payload: b}
length, err := datagram.Unmarshal(conn)
if err != nil {
return 0, nil, err
}
udpAddr := net.UDPAddr{IP: *datagram.IP, Port: int(datagram.Port), Zone: datagram.Zone}
return length, &udpAddr, nil
}
func (u *udpEncapsulator) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
conn, err := u.getConn()
if err != nil {
return 0, err
}
u.w.Lock()
defer u.w.Unlock()
datagram := &udpDatagram{payload: b, IP: &addr.IP, Port: uint16(addr.Port), Zone: addr.Zone}
return len(b), datagram.Marshal(conn)
}
func (u *udpEncapsulator) Close() error {
if u.conn != nil {
conn := *u.conn
conn.Close()
}
u.listener.Close()
return nil
}
func NewUDPListener(listener net.Listener) udpListener {
var m sync.Mutex
var r sync.Mutex
var w sync.Mutex
return &udpEncapsulator{
conn: nil,
listener: listener,
m: &m,
r: &r,
w: &w,
}
}
type udpDatagram struct {
payload []byte
IP *net.IP
Port uint16
Zone string
}
func (u *udpDatagram) Marshal(conn net.Conn) error {
// marshal the variable length header to a temporary buffer
var header bytes.Buffer
var length uint16
length = uint16(len(*u.IP))
if err := binary.Write(&header, binary.LittleEndian, &length); err != nil {
return err
}
if err := binary.Write(&header, binary.LittleEndian, u.IP); err != nil {
return err
}
if err := binary.Write(&header, binary.LittleEndian, &u.Port); err != nil {
return err
}
length = uint16(len(u.Zone))
if err := binary.Write(&header, binary.LittleEndian, &length); err != nil {
return err
}
if err := binary.Write(&header, binary.LittleEndian, []byte(u.Zone)); err != nil {
return nil
}
length = uint16(len(u.payload))
if err := binary.Write(&header, binary.LittleEndian, &length); err != nil {
return nil
}
length = uint16(2 + header.Len() + len(u.payload))
if err := binary.Write(conn, binary.LittleEndian, &length); err != nil {
return nil
}
_, err := io.Copy(conn, &header)
if err != nil {
return err
}
payload := bytes.NewBuffer(u.payload)
_, err = io.Copy(conn, payload)
if err != nil {
return err
}
return nil
}
func (u *udpDatagram) Unmarshal(conn net.Conn) (int, error) {
var length uint16
// frame length
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
return 0, err
}
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
return 0, err
}
var IP net.IP
IP = make([]byte, length)
if err := binary.Read(conn, binary.LittleEndian, &IP); err != nil {
return 0, err
}
u.IP = &IP
if err := binary.Read(conn, binary.LittleEndian, &u.Port); err != nil {
return 0, err
}
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
return 0, err
}
Zone := make([]byte, length)
if err := binary.Read(conn, binary.LittleEndian, &Zone); err != nil {
return 0, err
}
u.Zone = string(Zone)
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
return 0, err
}
_, err := io.ReadFull(conn, u.payload[0:length])
if err != nil {
return 0, err
}
return int(length), nil
}

View File

@ -47,22 +47,19 @@ type connTrackMap map[connTrackKey]*net.UDPConn
// interface to handle UDP traffic forwarding between the frontend and backend
// addresses.
type UDPProxy struct {
listener *net.UDPConn
frontendAddr *net.UDPAddr
listener udpListener
frontendAddr net.Addr
backendAddr *net.UDPAddr
connTrackTable connTrackMap
connTrackLock sync.Mutex
}
// NewUDPProxy creates a new UDPProxy.
func NewUDPProxy(frontendAddr, backendAddr *net.UDPAddr) (*UDPProxy, error) {
listener, err := net.ListenUDP("udp", frontendAddr)
if err != nil {
return nil, err
}
func NewUDPProxy(frontendAddr net.Addr, listener udpListener, backendAddr *net.UDPAddr) (*UDPProxy, error) {
return &UDPProxy{
listener: listener,
frontendAddr: listener.LocalAddr().(*net.UDPAddr),
frontendAddr: frontendAddr,
backendAddr: backendAddr,
connTrackTable: make(connTrackMap),
}, nil
@ -112,7 +109,7 @@ func (proxy *UDPProxy) Run() {
// ECONNREFUSED like Read do (see comment in
// UDPProxy.replyLoop)
if !isClosedError(err) {
logrus.Printf("Stopping proxy on udp/%v for udp/%v (%s)", proxy.frontendAddr, proxy.backendAddr, err)
logrus.Printf("Stopping proxy on %v for udp/%v (%s)", proxy.frontendAddr, proxy.backendAddr, err)
}
break
}

View File

@ -14,17 +14,17 @@ import (
func main() {
host, port, container := parseHostContainerAddrs()
p, err := libproxy.NewProxy(&vsock.VsockAddr{Port: uint(port)}, container)
if err != nil {
sendError(err)
}
ctl, err := exposePort(host, port)
if err != nil {
sendError(err)
}
p, err := libproxy.NewProxy(&vsock.VsockAddr{Port: uint(port)}, container)
if err != nil {
sendError(err)
}
go handleStopSignals(p)
// TODO: avoid this line if we are running in a TTY
sendOK()
p.Run()
ctl.Close() // ensure ctl remains alive and un-GCed until here
@ -32,7 +32,7 @@ func main() {
}
func exposePort(host net.Addr, port int) (*os.File, error) {
name := host.String()
name := host.Network() + ":" + host.String()
log.Printf("exposePort %s\n", name)
err := os.Mkdir("/port/"+name, 0)
if err != nil {

View File

@ -28,8 +28,11 @@ func sendOK() {
f.Close()
}
// Map dynamic ports onto vsock ports over this offset
var vSockPortOffset = 0x10000
// Map dynamic TCP ports onto vsock ports over this offset
var vSockTCPPortOffset = 0x10000
// Map dynamic UDP ports onto vsock ports over this offset
var vSockUDPPortOffset = 0x20000
// From docker/libnetwork/portmapper/proxy.go:
@ -49,11 +52,11 @@ func parseHostContainerAddrs() (host net.Addr, port int, container net.Addr) {
switch *proto {
case "tcp":
host = &net.TCPAddr{IP: net.ParseIP(*hostIP), Port: *hostPort}
port = vSockPortOffset + *hostPort
port = vSockTCPPortOffset + *hostPort
container = &net.TCPAddr{IP: net.ParseIP(*containerIP), Port: *containerPort}
case "udp":
host = &net.UDPAddr{IP: net.ParseIP(*hostIP), Port: *hostPort}
port = vSockPortOffset + *hostPort
port = vSockUDPPortOffset + *hostPort
container = &net.UDPAddr{IP: net.ParseIP(*containerIP), Port: *containerPort}
default:
log.Fatalf("unsupported protocol %s", *proto)