mirror of
https://github.com/go-ping/ping.git
synced 2025-08-31 12:52:04 +00:00
Make processPacket
more performant (#59)
* Make processPacket more performant * Add more info for debugging to returned error * remove old benchmark * change print statement to error return
This commit is contained in:
105
ping.go
105
ping.go
@@ -44,7 +44,8 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
@@ -60,6 +61,7 @@ import (
|
||||
|
||||
const (
|
||||
timeSliceLength = 8
|
||||
trackerLength = 8
|
||||
protocolICMP = 1
|
||||
protocolIPv6ICMP = 58
|
||||
)
|
||||
@@ -438,6 +440,7 @@ func (p *Pinger) recvICMP(
|
||||
}
|
||||
|
||||
func (p *Pinger) processPacket(recv *packet) error {
|
||||
receivedAt := time.Now()
|
||||
var bytes []byte
|
||||
var proto int
|
||||
if p.ipv4 {
|
||||
@@ -455,7 +458,7 @@ func (p *Pinger) processPacket(recv *packet) error {
|
||||
var m *icmp.Message
|
||||
var err error
|
||||
if m, err = icmp.ParseMessage(proto, bytes[:recv.nbytes]); err != nil {
|
||||
return fmt.Errorf("Error parsing icmp message")
|
||||
return fmt.Errorf("error parsing icmp message: %s", err.Error())
|
||||
}
|
||||
|
||||
if m.Type != ipv4.ICMPTypeEchoReply && m.Type != ipv6.ICMPTypeEchoReply {
|
||||
@@ -463,26 +466,6 @@ func (p *Pinger) processPacket(recv *packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
body := m.Body.(*icmp.Echo)
|
||||
// If we are priviledged, we can match icmp.ID
|
||||
if p.network == "ip" {
|
||||
// Check if reply from same ID
|
||||
if body.ID != p.id {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
// If we are not priviledged, we cannot set ID - require kernel ping_table map
|
||||
// need to use contents to identify packet
|
||||
data := IcmpData{}
|
||||
err := json.Unmarshal(body.Data, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if data.Tracker != p.Tracker {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
outPkt := &Packet{
|
||||
Nbytes: recv.nbytes,
|
||||
IPAddr: p.ipaddr,
|
||||
@@ -492,18 +475,33 @@ func (p *Pinger) processPacket(recv *packet) error {
|
||||
|
||||
switch pkt := m.Body.(type) {
|
||||
case *icmp.Echo:
|
||||
data := IcmpData{}
|
||||
err := json.Unmarshal(m.Body.(*icmp.Echo).Data, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
// If we are priviledged, we can match icmp.ID
|
||||
if p.network == "ip" {
|
||||
// Check if reply from same ID
|
||||
if pkt.ID != p.id {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
outPkt.Rtt = time.Since(bytesToTime(data.Bytes))
|
||||
|
||||
if len(pkt.Data) < timeSliceLength+trackerLength {
|
||||
return fmt.Errorf("insufficient data received; got: %d %v",
|
||||
len(pkt.Data), pkt.Data)
|
||||
}
|
||||
|
||||
tracker := bytesToInt(pkt.Data[timeSliceLength:])
|
||||
timestamp := bytesToTime(pkt.Data[:timeSliceLength])
|
||||
|
||||
if tracker != p.Tracker {
|
||||
return nil
|
||||
}
|
||||
|
||||
outPkt.Rtt = receivedAt.Sub(timestamp)
|
||||
outPkt.Seq = pkt.Seq
|
||||
p.PacketsRecv += 1
|
||||
p.PacketsRecv++
|
||||
default:
|
||||
// Very bad, not sure how this can happen
|
||||
return fmt.Errorf("Error, invalid ICMP echo reply. Body type: %T, %s",
|
||||
pkt, pkt)
|
||||
return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt)
|
||||
}
|
||||
|
||||
p.rtts = append(p.rtts, outPkt.Rtt)
|
||||
@@ -515,11 +513,6 @@ func (p *Pinger) processPacket(recv *packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type IcmpData struct {
|
||||
Bytes []byte
|
||||
Tracker int64
|
||||
}
|
||||
|
||||
func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
|
||||
var typ icmp.Type
|
||||
if p.ipv4 {
|
||||
@@ -533,42 +526,41 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
|
||||
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
|
||||
}
|
||||
|
||||
t := timeToBytes(time.Now())
|
||||
if p.Size-timeSliceLength != 0 {
|
||||
t = append(t, byteSliceOfSize(p.Size-timeSliceLength)...)
|
||||
t := append(timeToBytes(time.Now()), intToBytes(p.Tracker)...)
|
||||
if remainSize := p.Size - timeSliceLength - trackerLength; remainSize > 0 {
|
||||
t = append(t, bytes.Repeat([]byte{1}, remainSize)...)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(IcmpData{Bytes: t, Tracker: p.Tracker})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to marshal data %s", err)
|
||||
}
|
||||
body := &icmp.Echo{
|
||||
ID: p.id,
|
||||
Seq: p.sequence,
|
||||
Data: data,
|
||||
Data: t,
|
||||
}
|
||||
|
||||
msg := &icmp.Message{
|
||||
Type: typ,
|
||||
Code: 0,
|
||||
Body: body,
|
||||
}
|
||||
bytes, err := msg.Marshal(nil)
|
||||
|
||||
msgBytes, err := msg.Marshal(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
if _, err := conn.WriteTo(bytes, dst); err != nil {
|
||||
if _, err := conn.WriteTo(msgBytes, dst); err != nil {
|
||||
if neterr, ok := err.(*net.OpError); ok {
|
||||
if neterr.Err == syscall.ENOBUFS {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
p.PacketsSent += 1
|
||||
p.sequence += 1
|
||||
p.PacketsSent++
|
||||
p.sequence++
|
||||
break
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -582,15 +574,6 @@ func (p *Pinger) listen(netProto string) *icmp.PacketConn {
|
||||
return conn
|
||||
}
|
||||
|
||||
func byteSliceOfSize(n int) []byte {
|
||||
b := make([]byte, n)
|
||||
for i := 0; i < len(b); i++ {
|
||||
b[i] = 1
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func ipv4Payload(recv *packet) []byte {
|
||||
b := recv.bytes
|
||||
if len(b) < ipv4.HeaderLen {
|
||||
@@ -625,3 +608,13 @@ func timeToBytes(t time.Time) []byte {
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func bytesToInt(b []byte) int64 {
|
||||
return int64(binary.BigEndian.Uint64(b))
|
||||
}
|
||||
|
||||
func intToBytes(tracker int64) []byte {
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, uint64(tracker))
|
||||
return b
|
||||
}
|
||||
|
43
ping_test.go
43
ping_test.go
@@ -1,10 +1,14 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
func TestNewPingerValid(t *testing.T) {
|
||||
@@ -264,3 +268,42 @@ func AssertFalse(t *testing.T, b bool) {
|
||||
t.Errorf("Expected False, got True, Stack:\n%s", string(debug.Stack()))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkProcessPacket(b *testing.B) {
|
||||
pinger, _ := NewPinger("127.0.0.1")
|
||||
|
||||
pinger.ipv4 = true
|
||||
pinger.addr = "127.0.0.1"
|
||||
pinger.network = "ip4:icmp"
|
||||
pinger.id = 123
|
||||
pinger.Tracker = 456
|
||||
|
||||
t := append(timeToBytes(time.Now()), intToBytes(pinger.Tracker)...)
|
||||
if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 {
|
||||
t = append(t, bytes.Repeat([]byte{1}, remainSize)...)
|
||||
}
|
||||
|
||||
body := &icmp.Echo{
|
||||
ID: pinger.id,
|
||||
Seq: pinger.sequence,
|
||||
Data: t,
|
||||
}
|
||||
|
||||
msg := &icmp.Message{
|
||||
Type: ipv4.ICMPTypeEchoReply,
|
||||
Code: 0,
|
||||
Body: body,
|
||||
}
|
||||
|
||||
msgBytes, _ := msg.Marshal(nil)
|
||||
|
||||
pkt := packet{
|
||||
nbytes: len(msgBytes),
|
||||
bytes: msgBytes,
|
||||
ttl: 24,
|
||||
}
|
||||
|
||||
for k := 0; k < b.N; k++ {
|
||||
pinger.processPacket(&pkt)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user