diff --git a/ping.go b/ping.go index b9d64fa..4d1a870 100644 --- a/ping.go +++ b/ping.go @@ -44,7 +44,8 @@ package ping import ( - "encoding/json" + "bytes" + "encoding/binary" "fmt" "math" "math/rand" @@ -60,6 +61,7 @@ import ( const ( timeSliceLength = 8 + trackerLength = 8 protocolICMP = 1 protocolIPv6ICMP = 58 ) @@ -438,6 +440,7 @@ func (p *Pinger) recvICMP( } func (p *Pinger) processPacket(recv *packet) error { + receivedAt := time.Now() var bytes []byte var proto int if p.ipv4 { @@ -455,7 +458,7 @@ func (p *Pinger) processPacket(recv *packet) error { var m *icmp.Message var err error if m, err = icmp.ParseMessage(proto, bytes[:recv.nbytes]); err != nil { - return fmt.Errorf("Error parsing icmp message") + return fmt.Errorf("error parsing icmp message: %s", err.Error()) } if m.Type != ipv4.ICMPTypeEchoReply && m.Type != ipv6.ICMPTypeEchoReply { @@ -463,26 +466,6 @@ func (p *Pinger) processPacket(recv *packet) error { return nil } - body := m.Body.(*icmp.Echo) - // If we are priviledged, we can match icmp.ID - if p.network == "ip" { - // Check if reply from same ID - if body.ID != p.id { - return nil - } - } else { - // If we are not priviledged, we cannot set ID - require kernel ping_table map - // need to use contents to identify packet - data := IcmpData{} - err := json.Unmarshal(body.Data, &data) - if err != nil { - return err - } - if data.Tracker != p.Tracker { - return nil - } - } - outPkt := &Packet{ Nbytes: recv.nbytes, IPAddr: p.ipaddr, @@ -492,18 +475,33 @@ func (p *Pinger) processPacket(recv *packet) error { switch pkt := m.Body.(type) { case *icmp.Echo: - data := IcmpData{} - err := json.Unmarshal(m.Body.(*icmp.Echo).Data, &data) - if err != nil { - return err + + // If we are priviledged, we can match icmp.ID + if p.network == "ip" { + // Check if reply from same ID + if pkt.ID != p.id { + return nil + } } - outPkt.Rtt = time.Since(bytesToTime(data.Bytes)) + + if len(pkt.Data) < timeSliceLength+trackerLength { + return fmt.Errorf("insufficient data received; got: %d %v", + len(pkt.Data), pkt.Data) + } + + tracker := bytesToInt(pkt.Data[timeSliceLength:]) + timestamp := bytesToTime(pkt.Data[:timeSliceLength]) + + if tracker != p.Tracker { + return nil + } + + outPkt.Rtt = receivedAt.Sub(timestamp) outPkt.Seq = pkt.Seq - p.PacketsRecv += 1 + p.PacketsRecv++ default: // Very bad, not sure how this can happen - return fmt.Errorf("Error, invalid ICMP echo reply. Body type: %T, %s", - pkt, pkt) + return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt) } p.rtts = append(p.rtts, outPkt.Rtt) @@ -515,11 +513,6 @@ func (p *Pinger) processPacket(recv *packet) error { return nil } -type IcmpData struct { - Bytes []byte - Tracker int64 -} - func (p *Pinger) sendICMP(conn *icmp.PacketConn) error { var typ icmp.Type if p.ipv4 { @@ -533,42 +526,41 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error { dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone} } - t := timeToBytes(time.Now()) - if p.Size-timeSliceLength != 0 { - t = append(t, byteSliceOfSize(p.Size-timeSliceLength)...) + t := append(timeToBytes(time.Now()), intToBytes(p.Tracker)...) + if remainSize := p.Size - timeSliceLength - trackerLength; remainSize > 0 { + t = append(t, bytes.Repeat([]byte{1}, remainSize)...) } - data, err := json.Marshal(IcmpData{Bytes: t, Tracker: p.Tracker}) - if err != nil { - return fmt.Errorf("Unable to marshal data %s", err) - } body := &icmp.Echo{ ID: p.id, Seq: p.sequence, - Data: data, + Data: t, } + msg := &icmp.Message{ Type: typ, Code: 0, Body: body, } - bytes, err := msg.Marshal(nil) + + msgBytes, err := msg.Marshal(nil) if err != nil { return err } for { - if _, err := conn.WriteTo(bytes, dst); err != nil { + if _, err := conn.WriteTo(msgBytes, dst); err != nil { if neterr, ok := err.(*net.OpError); ok { if neterr.Err == syscall.ENOBUFS { continue } } } - p.PacketsSent += 1 - p.sequence += 1 + p.PacketsSent++ + p.sequence++ break } + return nil } @@ -582,15 +574,6 @@ func (p *Pinger) listen(netProto string) *icmp.PacketConn { return conn } -func byteSliceOfSize(n int) []byte { - b := make([]byte, n) - for i := 0; i < len(b); i++ { - b[i] = 1 - } - - return b -} - func ipv4Payload(recv *packet) []byte { b := recv.bytes if len(b) < ipv4.HeaderLen { @@ -625,3 +608,13 @@ func timeToBytes(t time.Time) []byte { } return b } + +func bytesToInt(b []byte) int64 { + return int64(binary.BigEndian.Uint64(b)) +} + +func intToBytes(tracker int64) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(tracker)) + return b +} diff --git a/ping_test.go b/ping_test.go index 709349e..45d5b0a 100644 --- a/ping_test.go +++ b/ping_test.go @@ -1,10 +1,14 @@ package ping import ( + "bytes" "net" "runtime/debug" "testing" "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" ) func TestNewPingerValid(t *testing.T) { @@ -264,3 +268,42 @@ func AssertFalse(t *testing.T, b bool) { t.Errorf("Expected False, got True, Stack:\n%s", string(debug.Stack())) } } + +func BenchmarkProcessPacket(b *testing.B) { + pinger, _ := NewPinger("127.0.0.1") + + pinger.ipv4 = true + pinger.addr = "127.0.0.1" + pinger.network = "ip4:icmp" + pinger.id = 123 + pinger.Tracker = 456 + + t := append(timeToBytes(time.Now()), intToBytes(pinger.Tracker)...) + if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 { + t = append(t, bytes.Repeat([]byte{1}, remainSize)...) + } + + body := &icmp.Echo{ + ID: pinger.id, + Seq: pinger.sequence, + Data: t, + } + + msg := &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: body, + } + + msgBytes, _ := msg.Marshal(nil) + + pkt := packet{ + nbytes: len(msgBytes), + bytes: msgBytes, + ttl: 24, + } + + for k := 0; k < b.N; k++ { + pinger.processPacket(&pkt) + } +}