diff --git a/ping.go b/ping.go index b2c96ba..fe51456 100644 --- a/ping.go +++ b/ping.go @@ -137,6 +137,13 @@ type Pinger struct { // Number of duplicate packets received PacketsRecvDuplicates int + // Round trip time statistics + minRtt time.Duration + maxRtt time.Duration + avgRtt time.Duration + stdDevRtt time.Duration + stddevm2 time.Duration + // If true, keep a record of rtts of all received packets. // Set to false to avoid memory bloat for long running pings. RecordRtts bool @@ -247,6 +254,27 @@ type Statistics struct { StdDevRtt time.Duration } +func (p *Pinger) updateStatistics(pkt *Packet) { + p.PacketsRecv++ + if p.PacketsRecv == 1 || pkt.Rtt < p.minRtt { + p.minRtt = pkt.Rtt + } + + if pkt.Rtt > p.maxRtt { + p.maxRtt = pkt.Rtt + } + + pktCount := time.Duration(p.PacketsRecv) + // welford's online method for stddev + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + delta := pkt.Rtt - p.avgRtt + p.avgRtt += delta / pktCount + delta2 := pkt.Rtt - p.avgRtt + p.stddevm2 += delta * delta2 + + p.stdDevRtt = time.Duration(math.Sqrt(float64(p.stddevm2 / pktCount))) +} + // SetIPAddr sets the ip address of the target host. func (p *Pinger) SetIPAddr(ipaddr *net.IPAddr) { p.ipv4 = isIPv4(ipaddr.IP) @@ -424,20 +452,6 @@ func (p *Pinger) finish() { // get it's finished statistics. func (p *Pinger) Statistics() *Statistics { loss := float64(p.PacketsSent-p.PacketsRecv) / float64(p.PacketsSent) * 100 - var min, max, total time.Duration - if len(p.rtts) > 0 { - min = p.rtts[0] - max = p.rtts[0] - } - for _, rtt := range p.rtts { - if rtt < min { - min = rtt - } - if rtt > max { - max = rtt - } - total += rtt - } s := Statistics{ PacketsSent: p.PacketsSent, PacketsRecv: p.PacketsRecv, @@ -446,17 +460,10 @@ func (p *Pinger) Statistics() *Statistics { Rtts: p.rtts, Addr: p.addr, IPAddr: p.ipaddr, - MaxRtt: max, - MinRtt: min, - } - if len(p.rtts) > 0 { - s.AvgRtt = total / time.Duration(len(p.rtts)) - var sumsquares time.Duration - for _, rtt := range p.rtts { - sumsquares += (rtt - s.AvgRtt) * (rtt - s.AvgRtt) - } - s.StdDevRtt = time.Duration(math.Sqrt( - float64(sumsquares / time.Duration(len(p.rtts))))) + MaxRtt: p.maxRtt, + MinRtt: p.minRtt, + AvgRtt: p.avgRtt, + StdDevRtt: p.stdDevRtt, } return &s } @@ -532,7 +539,7 @@ func (p *Pinger) processPacket(recv *packet) error { return nil } - outPkt := &Packet{ + inPkt := &Packet{ Nbytes: recv.nbytes, IPAddr: p.ipaddr, Addr: p.addr, @@ -541,8 +548,7 @@ func (p *Pinger) processPacket(recv *packet) error { switch pkt := m.Body.(type) { case *icmp.Echo: - // Check if the reply has the ID we expect. - if pkt.ID != p.id { + if !p.matchID(pkt.ID) { return nil } @@ -558,30 +564,30 @@ func (p *Pinger) processPacket(recv *packet) error { return nil } - outPkt.Rtt = receivedAt.Sub(timestamp) - outPkt.Seq = pkt.Seq + inPkt.Rtt = receivedAt.Sub(timestamp) + inPkt.Seq = pkt.Seq // If we've already received this sequence, ignore it. if _, inflight := p.awaitingSequences[pkt.Seq]; !inflight { p.PacketsRecvDuplicates++ if p.OnDuplicateRecv != nil { - p.OnDuplicateRecv(outPkt) + p.OnDuplicateRecv(inPkt) } return nil } // remove it from the list of sequences we're waiting for so we don't get duplicates. delete(p.awaitingSequences, pkt.Seq) - p.PacketsRecv++ + p.updateStatistics(inPkt) default: // Very bad, not sure how this can happen return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt) } if p.RecordRtts { - p.rtts = append(p.rtts, outPkt.Rtt) + p.rtts = append(p.rtts, inPkt.Rtt) } handler := p.OnRecv if handler != nil { - handler(outPkt) + handler(inPkt) } return nil diff --git a/ping_test.go b/ping_test.go index b40981c..14ff23b 100644 --- a/ping_test.go +++ b/ping_test.go @@ -373,19 +373,16 @@ func TestStatisticsSunny(t *testing.T) { AssertEqualStrings(t, "localhost", p.Addr()) p.PacketsSent = 10 - p.PacketsRecv = 10 - p.rtts = []time.Duration{ - time.Duration(1000), - time.Duration(1000), - time.Duration(1000), - time.Duration(1000), - time.Duration(1000), - time.Duration(1000), - time.Duration(1000), - time.Duration(1000), - time.Duration(1000), - time.Duration(1000), - } + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) stats := p.Statistics() if stats.PacketsRecv != 10 { @@ -419,19 +416,16 @@ func TestStatisticsLossy(t *testing.T) { AssertEqualStrings(t, "localhost", p.Addr()) p.PacketsSent = 20 - p.PacketsRecv = 10 - p.rtts = []time.Duration{ - time.Duration(10), - time.Duration(1000), - time.Duration(1000), - time.Duration(10000), - time.Duration(1000), - time.Duration(800), - time.Duration(1000), - time.Duration(40), - time.Duration(100000), - time.Duration(1000), - } + p.updateStatistics(&Packet{Rtt: time.Duration(10)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(10000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(800)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(40)}) + p.updateStatistics(&Packet{Rtt: time.Duration(100000)}) + p.updateStatistics(&Packet{Rtt: time.Duration(1000)}) stats := p.Statistics() if stats.PacketsRecv != 10 { diff --git a/utils_linux.go b/utils_linux.go new file mode 100644 index 0000000..ba785d2 --- /dev/null +++ b/utils_linux.go @@ -0,0 +1,19 @@ +// +build linux + +package ping + +// Returns the length of an ICMP message. +func (p *Pinger) getMessageLength() int { + return p.Size + 8 +} + +// Attempts to match the ID of an ICMP packet. +func (p *Pinger) matchID(ID int) bool { + // On Linux we can only match ID if we are privileged. + if p.protocol == "icmp" { + if ID != p.id { + return false + } + } + return true +} diff --git a/utils_other.go b/utils_other.go index cd73ab2..6ccbe78 100644 --- a/utils_other.go +++ b/utils_other.go @@ -1,4 +1,4 @@ -// +build !windows +// +build !linux,!windows package ping @@ -6,3 +6,11 @@ package ping func (p *Pinger) getMessageLength() int { return p.Size + 8 } + +// Attempts to match the ID of an ICMP packet. +func (p *Pinger) matchID(ID int) bool { + if ID != p.id { + return false + } + return true +} diff --git a/utils_windows.go b/utils_windows.go index 2861f64..ba642bc 100644 --- a/utils_windows.go +++ b/utils_windows.go @@ -14,3 +14,11 @@ func (p *Pinger) getMessageLength() int { } return p.Size + 8 + ipv6.HeaderLen } + +// Attempts to match the ID of an ICMP packet. +func (p *Pinger) matchID(ID int) bool { + if ID != p.id { + return false + } + return true +}