diff --git a/README.md b/README.md index 7fbdd4d..1d999b8 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ err = pinger.Run() // Blocks until finished. if err != nil { panic(err) } -stats := pinger.Statistics() // get send/receive/rtt stats +stats := pinger.Statistics() // get send/receive/duplicate/rtt stats ``` Here is an example that emulates the traditional UNIX ping command: @@ -42,6 +42,11 @@ pinger.OnRecv = func(pkt *ping.Packet) { pkt.Nbytes, pkt.IPAddr, pkt.Seq, pkt.Rtt) } +pinger.OnDuplicateRecv = func(pkt *ping.Packet) { + fmt.Printf("%d bytes from %s: icmp_seq=%d time=%v ttl=%v (DUP!)\n", + pkt.Nbytes, pkt.IPAddr, pkt.Seq, pkt.Rtt, pkt.Ttl) +} + pinger.OnFinish = func(stats *ping.Statistics) { fmt.Printf("\n--- %s ping statistics ---\n", stats.Addr) fmt.Printf("%d packets transmitted, %d packets received, %v%% packet loss\n", @@ -58,8 +63,10 @@ if err != nil { ``` It sends ICMP Echo Request packet(s) and waits for an Echo Reply in -response. If it receives a response, it calls the `OnRecv` callback. -When it's finished, it calls the `OnFinish` callback. +response. If it receives a response, it calls the `OnRecv` callback +unless a packet with that sequence number has already been received, +in which case it calls the `OnDuplicateRecv` callback. When it's +finished, it calls the `OnFinish` callback. For a full ping example, see [cmd/ping/ping.go](https://github.com/go-ping/ping/blob/master/cmd/ping/ping.go). diff --git a/cmd/ping/ping.go b/cmd/ping/ping.go index a3f028f..beba292 100644 --- a/cmd/ping/ping.go +++ b/cmd/ping/ping.go @@ -68,10 +68,14 @@ func main() { fmt.Printf("%d bytes from %s: icmp_seq=%d time=%v ttl=%v\n", pkt.Nbytes, pkt.IPAddr, pkt.Seq, pkt.Rtt, pkt.Ttl) } + pinger.OnDuplicateRecv = func(pkt *ping.Packet) { + fmt.Printf("%d bytes from %s: icmp_seq=%d time=%v ttl=%v (DUP!)\n", + pkt.Nbytes, pkt.IPAddr, pkt.Seq, pkt.Rtt, pkt.Ttl) + } pinger.OnFinish = func(stats *ping.Statistics) { fmt.Printf("\n--- %s ping statistics ---\n", stats.Addr) - fmt.Printf("%d packets transmitted, %d packets received, %v%% packet loss\n", - stats.PacketsSent, stats.PacketsRecv, stats.PacketLoss) + fmt.Printf("%d packets transmitted, %d packets received, %d duplicates, %v%% packet loss\n", + stats.PacketsSent, stats.PacketsRecv, stats.PacketsRecvDuplicates, stats.PacketLoss) fmt.Printf("round-trip min/avg/max/stddev = %v/%v/%v/%v\n", stats.MinRtt, stats.AvgRtt, stats.MaxRtt, stats.StdDevRtt) } diff --git a/ping.go b/ping.go index 0ed0849..6936c94 100644 --- a/ping.go +++ b/ping.go @@ -62,6 +62,7 @@ import ( "net" "runtime" "sync" + "sync/atomic" "syscall" "time" @@ -84,7 +85,7 @@ var ( // New returns a new Pinger struct pointer. func New(addr string) *Pinger { - r := rand.New(rand.NewSource(time.Now().UnixNano())) + r := rand.New(rand.NewSource(getSeed())) return &Pinger{ Count: -1, Interval: time.Second, @@ -93,13 +94,14 @@ func New(addr string) *Pinger { Timeout: time.Second * 100000, Tracker: r.Int63n(math.MaxInt64), - addr: addr, - done: make(chan bool), - id: r.Intn(math.MaxInt16), - ipaddr: nil, - ipv4: false, - network: "ip", - protocol: "udp", + addr: addr, + done: make(chan bool), + id: r.Intn(math.MaxInt16), + ipaddr: nil, + ipv4: false, + network: "ip", + protocol: "udp", + awaitingSequences: map[int]struct{}{}, } } @@ -132,6 +134,9 @@ type Pinger struct { // Number of packets received PacketsRecv int + // Number of duplicate packets received + PacketsRecvDuplicates int + // If true, keep a record of rtts of all received packets. // Set to false to avoid memory bloat for long running pings. RecordRtts bool @@ -148,6 +153,9 @@ type Pinger struct { // OnFinish is called when Pinger exits OnFinish func(*Statistics) + // OnDuplicateRecv is called when a packet is received that has already been received. + OnDuplicateRecv func(*Packet) + // Size of packet being sent Size int @@ -166,6 +174,8 @@ type Pinger struct { ipv4 bool id int sequence int + // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts + awaitingSequences map[int]struct{} // network is one of "ip", "ip4", or "ip6". network string // protocol is "icmp" or "udp". @@ -208,6 +218,9 @@ type Statistics struct { // PacketsSent is the number of packets sent. PacketsSent int + // PacketsRecvDuplicates is the number of duplicate responses there were to a sent packet. + PacketsRecvDuplicates int + // PacketLoss is the percentage of packets lost. PacketLoss float64 @@ -426,14 +439,15 @@ func (p *Pinger) Statistics() *Statistics { total += rtt } s := Statistics{ - PacketsSent: p.PacketsSent, - PacketsRecv: p.PacketsRecv, - PacketLoss: loss, - Rtts: p.rtts, - Addr: p.addr, - IPAddr: p.ipaddr, - MaxRtt: max, - MinRtt: min, + PacketsSent: p.PacketsSent, + PacketsRecv: p.PacketsRecv, + PacketsRecvDuplicates: p.PacketsRecvDuplicates, + PacketLoss: loss, + Rtts: p.rtts, + Addr: p.addr, + IPAddr: p.ipaddr, + MaxRtt: max, + MinRtt: min, } if len(p.rtts) > 0 { s.AvgRtt = total / time.Duration(len(p.rtts)) @@ -549,6 +563,16 @@ func (p *Pinger) processPacket(recv *packet) error { outPkt.Rtt = receivedAt.Sub(timestamp) outPkt.Seq = pkt.Seq + // If we've already received this sequence, ignore it. + if _, inflight := p.awaitingSequences[pkt.Seq]; !inflight { + p.PacketsRecvDuplicates++ + if p.OnDuplicateRecv != nil { + p.OnDuplicateRecv(outPkt) + } + return nil + } + // remove it from the list of sequences we're waiting for so we don't get duplicates. + delete(p.awaitingSequences, pkt.Seq) p.PacketsRecv++ default: // Very bad, not sure how this can happen @@ -619,7 +643,8 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error { } handler(outPkt) } - + // mark this sequence as in-flight + p.awaitingSequences[p.sequence] = struct{}{} p.PacketsSent++ p.sequence++ break @@ -667,3 +692,10 @@ func intToBytes(tracker int64) []byte { binary.BigEndian.PutUint64(b, uint64(tracker)) return b } + +var seed int64 = time.Now().UnixNano() + +// getSeed returns a goroutine-safe unique seed +func getSeed() int64 { + return atomic.AddInt64(&seed, 1) +} diff --git a/ping_test.go b/ping_test.go index 0616431..6c99a8c 100644 --- a/ping_test.go +++ b/ping_test.go @@ -29,6 +29,7 @@ func TestProcessPacket(t *testing.T) { Seq: pinger.sequence, Data: data, } + pinger.awaitingSequences[pinger.sequence] = struct{}{} msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply, @@ -548,3 +549,56 @@ func BenchmarkProcessPacket(b *testing.B) { pinger.processPacket(&pkt) } } + +func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { + pinger := makeTestPinger() + // pinger.protocol = "icmp" // ID is only checked on "icmp" protocol + shouldBe0 := 0 + dups := 0 + + // this function should not be called because the tracker is mismatched + pinger.OnRecv = func(pkt *Packet) { + shouldBe0++ + } + + pinger.OnDuplicateRecv = func(pkt *Packet) { + dups++ + } + + data := append(timeToBytes(time.Now()), intToBytes(pinger.Tracker)...) + if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 { + data = append(data, bytes.Repeat([]byte{1}, remainSize)...) + } + + body := &icmp.Echo{ + ID: 123, + Seq: 0, + Data: data, + } + // register the sequence as sent + pinger.awaitingSequences[0] = struct{}{} + + msg := &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: body, + } + + msgBytes, _ := msg.Marshal(nil) + + pkt := packet{ + nbytes: len(msgBytes), + bytes: msgBytes, + ttl: 24, + } + + err := pinger.processPacket(&pkt) + AssertNoError(t, err) + // receive a duplicate + err = pinger.processPacket(&pkt) + AssertNoError(t, err) + + AssertTrue(t, shouldBe0 == 1) + AssertTrue(t, dups == 1) + AssertTrue(t, pinger.PacketsRecvDuplicates == 1) +}