Fix packet matching; compute statistics as packets are received (#150)

* Fix unprivileged packet matching on Linux
* Compute statistics on-the-fly as packets are received

Signed-off-by: Jean Raby <jean@raby.sh>
Signed-off-by: Charlie Jonas <charlie@charliejonas.co.uk>
Co-authored-by: Charlie Jonas <charlie@charliejonas.co.uk>
This commit is contained in:
Jean Raby
2021-03-11 17:01:31 -05:00
committed by GitHub
parent 25d1413fb7
commit 5f9dd908cc
5 changed files with 96 additions and 61 deletions

74
ping.go
View File

@@ -137,6 +137,13 @@ type Pinger struct {
// Number of duplicate packets received // Number of duplicate packets received
PacketsRecvDuplicates int PacketsRecvDuplicates int
// Round trip time statistics
minRtt time.Duration
maxRtt time.Duration
avgRtt time.Duration
stdDevRtt time.Duration
stddevm2 time.Duration
// If true, keep a record of rtts of all received packets. // If true, keep a record of rtts of all received packets.
// Set to false to avoid memory bloat for long running pings. // Set to false to avoid memory bloat for long running pings.
RecordRtts bool RecordRtts bool
@@ -247,6 +254,27 @@ type Statistics struct {
StdDevRtt time.Duration StdDevRtt time.Duration
} }
func (p *Pinger) updateStatistics(pkt *Packet) {
p.PacketsRecv++
if p.PacketsRecv == 1 || pkt.Rtt < p.minRtt {
p.minRtt = pkt.Rtt
}
if pkt.Rtt > p.maxRtt {
p.maxRtt = pkt.Rtt
}
pktCount := time.Duration(p.PacketsRecv)
// welford's online method for stddev
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
delta := pkt.Rtt - p.avgRtt
p.avgRtt += delta / pktCount
delta2 := pkt.Rtt - p.avgRtt
p.stddevm2 += delta * delta2
p.stdDevRtt = time.Duration(math.Sqrt(float64(p.stddevm2 / pktCount)))
}
// SetIPAddr sets the ip address of the target host. // SetIPAddr sets the ip address of the target host.
func (p *Pinger) SetIPAddr(ipaddr *net.IPAddr) { func (p *Pinger) SetIPAddr(ipaddr *net.IPAddr) {
p.ipv4 = isIPv4(ipaddr.IP) p.ipv4 = isIPv4(ipaddr.IP)
@@ -424,20 +452,6 @@ func (p *Pinger) finish() {
// get it's finished statistics. // get it's finished statistics.
func (p *Pinger) Statistics() *Statistics { func (p *Pinger) Statistics() *Statistics {
loss := float64(p.PacketsSent-p.PacketsRecv) / float64(p.PacketsSent) * 100 loss := float64(p.PacketsSent-p.PacketsRecv) / float64(p.PacketsSent) * 100
var min, max, total time.Duration
if len(p.rtts) > 0 {
min = p.rtts[0]
max = p.rtts[0]
}
for _, rtt := range p.rtts {
if rtt < min {
min = rtt
}
if rtt > max {
max = rtt
}
total += rtt
}
s := Statistics{ s := Statistics{
PacketsSent: p.PacketsSent, PacketsSent: p.PacketsSent,
PacketsRecv: p.PacketsRecv, PacketsRecv: p.PacketsRecv,
@@ -446,17 +460,10 @@ func (p *Pinger) Statistics() *Statistics {
Rtts: p.rtts, Rtts: p.rtts,
Addr: p.addr, Addr: p.addr,
IPAddr: p.ipaddr, IPAddr: p.ipaddr,
MaxRtt: max, MaxRtt: p.maxRtt,
MinRtt: min, MinRtt: p.minRtt,
} AvgRtt: p.avgRtt,
if len(p.rtts) > 0 { StdDevRtt: p.stdDevRtt,
s.AvgRtt = total / time.Duration(len(p.rtts))
var sumsquares time.Duration
for _, rtt := range p.rtts {
sumsquares += (rtt - s.AvgRtt) * (rtt - s.AvgRtt)
}
s.StdDevRtt = time.Duration(math.Sqrt(
float64(sumsquares / time.Duration(len(p.rtts)))))
} }
return &s return &s
} }
@@ -532,7 +539,7 @@ func (p *Pinger) processPacket(recv *packet) error {
return nil return nil
} }
outPkt := &Packet{ inPkt := &Packet{
Nbytes: recv.nbytes, Nbytes: recv.nbytes,
IPAddr: p.ipaddr, IPAddr: p.ipaddr,
Addr: p.addr, Addr: p.addr,
@@ -541,8 +548,7 @@ func (p *Pinger) processPacket(recv *packet) error {
switch pkt := m.Body.(type) { switch pkt := m.Body.(type) {
case *icmp.Echo: case *icmp.Echo:
// Check if the reply has the ID we expect. if !p.matchID(pkt.ID) {
if pkt.ID != p.id {
return nil return nil
} }
@@ -558,30 +564,30 @@ func (p *Pinger) processPacket(recv *packet) error {
return nil return nil
} }
outPkt.Rtt = receivedAt.Sub(timestamp) inPkt.Rtt = receivedAt.Sub(timestamp)
outPkt.Seq = pkt.Seq inPkt.Seq = pkt.Seq
// If we've already received this sequence, ignore it. // If we've already received this sequence, ignore it.
if _, inflight := p.awaitingSequences[pkt.Seq]; !inflight { if _, inflight := p.awaitingSequences[pkt.Seq]; !inflight {
p.PacketsRecvDuplicates++ p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil { if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(outPkt) p.OnDuplicateRecv(inPkt)
} }
return nil return nil
} }
// remove it from the list of sequences we're waiting for so we don't get duplicates. // remove it from the list of sequences we're waiting for so we don't get duplicates.
delete(p.awaitingSequences, pkt.Seq) delete(p.awaitingSequences, pkt.Seq)
p.PacketsRecv++ p.updateStatistics(inPkt)
default: default:
// Very bad, not sure how this can happen // Very bad, not sure how this can happen
return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt) return fmt.Errorf("invalid ICMP echo reply; type: '%T', '%v'", pkt, pkt)
} }
if p.RecordRtts { if p.RecordRtts {
p.rtts = append(p.rtts, outPkt.Rtt) p.rtts = append(p.rtts, inPkt.Rtt)
} }
handler := p.OnRecv handler := p.OnRecv
if handler != nil { if handler != nil {
handler(outPkt) handler(inPkt)
} }
return nil return nil

View File

@@ -373,19 +373,16 @@ func TestStatisticsSunny(t *testing.T) {
AssertEqualStrings(t, "localhost", p.Addr()) AssertEqualStrings(t, "localhost", p.Addr())
p.PacketsSent = 10 p.PacketsSent = 10
p.PacketsRecv = 10 p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
p.rtts = []time.Duration{ p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(1000),
time.Duration(1000),
}
stats := p.Statistics() stats := p.Statistics()
if stats.PacketsRecv != 10 { if stats.PacketsRecv != 10 {
@@ -419,19 +416,16 @@ func TestStatisticsLossy(t *testing.T) {
AssertEqualStrings(t, "localhost", p.Addr()) AssertEqualStrings(t, "localhost", p.Addr())
p.PacketsSent = 20 p.PacketsSent = 20
p.PacketsRecv = 10 p.updateStatistics(&Packet{Rtt: time.Duration(10)})
p.rtts = []time.Duration{ p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(10), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(10000)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(10000), p.updateStatistics(&Packet{Rtt: time.Duration(800)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(800), p.updateStatistics(&Packet{Rtt: time.Duration(40)})
time.Duration(1000), p.updateStatistics(&Packet{Rtt: time.Duration(100000)})
time.Duration(40), p.updateStatistics(&Packet{Rtt: time.Duration(1000)})
time.Duration(100000),
time.Duration(1000),
}
stats := p.Statistics() stats := p.Statistics()
if stats.PacketsRecv != 10 { if stats.PacketsRecv != 10 {

19
utils_linux.go Normal file
View File

@@ -0,0 +1,19 @@
// +build linux
package ping
// Returns the length of an ICMP message.
func (p *Pinger) getMessageLength() int {
return p.Size + 8
}
// Attempts to match the ID of an ICMP packet.
func (p *Pinger) matchID(ID int) bool {
// On Linux we can only match ID if we are privileged.
if p.protocol == "icmp" {
if ID != p.id {
return false
}
}
return true
}

View File

@@ -1,4 +1,4 @@
// +build !windows // +build !linux,!windows
package ping package ping
@@ -6,3 +6,11 @@ package ping
func (p *Pinger) getMessageLength() int { func (p *Pinger) getMessageLength() int {
return p.Size + 8 return p.Size + 8
} }
// Attempts to match the ID of an ICMP packet.
func (p *Pinger) matchID(ID int) bool {
if ID != p.id {
return false
}
return true
}

View File

@@ -14,3 +14,11 @@ func (p *Pinger) getMessageLength() int {
} }
return p.Size + 8 + ipv6.HeaderLen return p.Size + 8 + ipv6.HeaderLen
} }
// Attempts to match the ID of an ICMP packet.
func (p *Pinger) matchID(ID int) bool {
if ID != p.id {
return false
}
return true
}