diff --git a/alpine/packages/proxy/libproxy/udp_encapsulation.go b/alpine/packages/proxy/libproxy/udp_encapsulation.go index 1891aa234..df837a896 100644 --- a/alpine/packages/proxy/libproxy/udp_encapsulation.go +++ b/alpine/packages/proxy/libproxy/udp_encapsulation.go @@ -91,30 +91,40 @@ type udpDatagram struct { } func (u *udpDatagram) Marshal(conn net.Conn) error { + // marshal the variable length header to a temporary buffer + var header bytes.Buffer var length uint16 length = uint16(len(*u.IP)) - if err := binary.Write(conn, binary.LittleEndian, &length); err != nil { + if err := binary.Write(&header, binary.LittleEndian, &length); err != nil { return err } - if err := binary.Write(conn, binary.LittleEndian, &u.IP); err != nil { + if err := binary.Write(&header, binary.LittleEndian, &u.IP); err != nil { return err } - if err := binary.Write(conn, binary.LittleEndian, &u.Port); err != nil { + if err := binary.Write(&header, binary.LittleEndian, &u.Port); err != nil { return err } length = uint16(len(u.Zone)) - if err := binary.Write(conn, binary.LittleEndian, &length); err != nil { + if err := binary.Write(&header, binary.LittleEndian, &length); err != nil { return err } - if err := binary.Write(conn, binary.LittleEndian, &u.Zone); err != nil { + if err := binary.Write(&header, binary.LittleEndian, &u.Zone); err != nil { return nil } length = uint16(len(u.payload)) + if err := binary.Write(&header, binary.LittleEndian, &length); err != nil { + return nil + } + length = uint16(header.Len() + len(u.payload)) if err := binary.Write(conn, binary.LittleEndian, &length); err != nil { return nil } + _, err := io.Copy(conn, &header) + if err != nil { + return err + } payload := bytes.NewBuffer(u.payload) - _, err := io.Copy(conn, payload) + _, err = io.Copy(conn, payload) if err != nil { return err } @@ -123,6 +133,10 @@ func (u *udpDatagram) Marshal(conn net.Conn) error { func (u *udpDatagram) Unmarshal(conn net.Conn) error { var length uint16 + // frame length + if err := binary.Read(conn, binary.LittleEndian, &length); err != nil { + return err + } if err := binary.Read(conn, binary.LittleEndian, &length); err != nil { return err }