From 7e850a144486857bf68b391843f95954966a98e3 Mon Sep 17 00:00:00 2001 From: Florian Loch Date: Sun, 27 Feb 2022 22:06:48 +0100 Subject: [PATCH] feature: implement packet timeout to detect lost packets --- ping.go | 232 ++++++++++++++++++++++++++++++++++++++------------- ping_test.go | 6 +- 2 files changed, 177 insertions(+), 61 deletions(-) diff --git a/ping.go b/ping.go index 2854e0a..4ef7ef0 100644 --- a/ping.go +++ b/ping.go @@ -103,9 +103,10 @@ func New(addr string) *Pinger { ipv4: false, network: "ip", protocol: "udp", - awaitingSequences: make(map[string]struct{}), + awaitingSequences: make(map[string]time.Time), TTL: 64, logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, + PacketTimeout: 100 * time.Millisecond, } } @@ -122,6 +123,7 @@ type Pinger struct { // Timeout specifies a timeout before ping exits, regardless of how many // packets have been received. + // This is not to be confused with PacketTimeout. Timeout time.Duration // Count tells pinger to stop after sending (and receiving) Count echo @@ -141,6 +143,9 @@ type Pinger struct { // Number of duplicate packets received PacketsRecvDuplicates int + // PacketsLost counts packets that have not been answered (within PacketTimeout) + PacketsLost int + // Round trip time statistics minRtt time.Duration maxRtt time.Duration @@ -171,6 +176,10 @@ type Pinger struct { // OnDuplicateRecv is called when a packet is received that has already been received. OnDuplicateRecv func(*Packet) + // OnLost is called when Pinger considers a packet lost. + // This will happen when there is no matching response for >= PacketTimeout. + OnLost func(usedUUID uuid.UUID, sequenceID int, noResponseAfter time.Duration) + // Size of packet being sent Size int @@ -193,8 +202,11 @@ type Pinger struct { ipv4 bool id int sequence int - // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts - awaitingSequences map[string]struct{} + + // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts. + // This map does not need synchronization/locking because it is only ever accessed from one goroutine. + awaitingSequences map[string]time.Time + // network is one of "ip", "ip4", or "ip6". network string // protocol is "icmp" or "udp". @@ -202,7 +214,17 @@ type Pinger struct { logger Logger + // TTL is the number of hops a ping packet is allowed before being discarded. + // With IPv4 it maps to the TTL header field, with IPv6 to the Hop Limit one. + // TTL has to be >=1 and <=255 as both header fields are limited to 8 bit and a hop limit of 0 is not valid. + // TODO: Perhaps this should be enforced by changing the type to uin8 or by hiding the field behind a setter? TTL int + + // PacketTimeout is the duration after which a package will be considered lost. + // Defaults to math.MaxInt64 - which practically means it will never be considered lost. + // Checking whether a package is lost will be performed every PacketTimeout. + // If a response arrives after PacketTimeout but before the check gets performed it will NOT be considered lost. + PacketTimeout time.Duration } type packet struct { @@ -250,6 +272,9 @@ type Statistics struct { // PacketLoss is the percentage of packets lost. PacketLoss float64 + // PacketsLost is the actual amount of lost packets + PacketsLost int + // IPAddr is the address of the host being pinged. IPAddr *net.IPAddr @@ -405,12 +430,17 @@ func (p *Pinger) Run() error { if p.Size < timeSliceLength+trackerLength { return fmt.Errorf("size %d is less than minimum required size %d", p.Size, timeSliceLength+trackerLength) } + + if p.TTL < 1 || p.TTL > 255 { + return fmt.Errorf("TTL %d out of range; has to be >= 1 and <= 255", p.TTL) + } + if p.ipaddr == nil { - err = p.Resolve() - } - if err != nil { - return err + if err := p.Resolve(); err != nil { + return err + } } + if conn, err = p.listen(); err != nil { return err } @@ -459,6 +489,17 @@ func (p *Pinger) runLoop( timeout := time.NewTicker(p.Timeout) interval := time.NewTicker(p.Interval) + + var intervalLostPacketsCheck <-chan time.Time + + // In case it is zero NewTicker would panic, furthermore 0 is defined as "packets never timeout" + if p.PacketTimeout > 0 { + t := time.NewTicker(p.PacketTimeout) + defer t.Stop() + + intervalLostPacketsCheck = t.C + } + defer func() { interval.Stop() timeout.Stop() @@ -493,8 +534,12 @@ func (p *Pinger) runLoop( // FIXME: this logs as FATAL but continues logger.Fatalf("sending packet: %s", err) } + + case <-intervalLostPacketsCheck: + p.checkForLostPackets() } - if p.Count > 0 && p.PacketsRecv >= p.Count { + + if p.Count > 0 && p.PacketsRecv+p.PacketsLost >= p.Count { return nil } } @@ -529,13 +574,13 @@ func (p *Pinger) finish() { func (p *Pinger) Statistics() *Statistics { p.statsMu.RLock() defer p.statsMu.RUnlock() - sent := p.PacketsSent - loss := float64(sent-p.PacketsRecv) / float64(sent) * 100 - s := Statistics{ - PacketsSent: sent, + + return &Statistics{ + PacketsSent: p.PacketsSent, PacketsRecv: p.PacketsRecv, PacketsRecvDuplicates: p.PacketsRecvDuplicates, - PacketLoss: loss, + PacketLoss: float64(p.PacketsLost) / float64(p.PacketsSent) * 100, + PacketsLost: p.PacketsLost, Rtts: p.rtts, Addr: p.addr, IPAddr: p.ipaddr, @@ -544,7 +589,6 @@ func (p *Pinger) Statistics() *Statistics { AvgRtt: p.avgRtt, StdDevRtt: p.stdDevRtt, } - return &s } type expBackoff struct { @@ -644,24 +688,31 @@ func (p *Pinger) processPacket(recv *packet) error { len(pkt.Data), pkt.Data) } - var pktUUID uuid.UUID - err = pktUUID.UnmarshalBinary(pkt.Data[timeSliceLength : timeSliceLength+trackerLength]) + pktUUID, err := uuid.FromBytes(pkt.Data[timeSliceLength : timeSliceLength+trackerLength]) if err != nil { return fmt.Errorf("error decoding tracking UUID: %w", err) } - timestamp := bytesToTime(pkt.Data[:timeSliceLength]) - inPkt.Rtt = receivedAt.Sub(timestamp) + sentAt := bytesToTime(pkt.Data[:timeSliceLength]) + inPkt.Rtt = receivedAt.Sub(sentAt) inPkt.Seq = pkt.Seq key := buildLookupKey(pktUUID, pkt.Seq) // If we've already received this sequence, ignore it. if _, inflight := p.awaitingSequences[key]; !inflight { + // Check whether this isn't a duplicate but a response that has been declared lost already and therefore + // isn't present in awaitingSequences anymore + // If PacketTimeout is set to 0, packets shall never time out. + if p.PacketTimeout != 0 && receivedAt.Sub(sentAt) >= p.PacketTimeout { + return nil + } + p.PacketsRecvDuplicates++ if p.OnDuplicateRecv != nil { p.OnDuplicateRecv(inPkt) } + return nil } @@ -691,59 +742,74 @@ func (p *Pinger) sendICMP(conn packetConn) error { if err != nil { return fmt.Errorf("unable to marshal UUID binary: %w", err) } - t := append(timeToBytes(time.Now()), uuidEncoded...) - if remainSize := p.Size - timeSliceLength - trackerLength; remainSize > 0 { - t = append(t, bytes.Repeat([]byte{1}, remainSize)...) - } - body := &icmp.Echo{ - ID: p.id, - Seq: p.sequence, - Data: t, - } - - msg := &icmp.Message{ - Type: conn.ICMPRequestType(), - Code: 0, - Body: body, - } - - msgBytes, err := msg.Marshal(nil) - if err != nil { - return err - } + var ( + sentAt time.Time + msgBytes []byte + ) for { + sentAt = time.Now() + + t := append(timeToBytes(sentAt), uuidEncoded...) + + if remainSize := p.Size - timeSliceLength - trackerLength; remainSize > 0 { + t = append(t, bytes.Repeat([]byte{1}, remainSize)...) + } + + body := &icmp.Echo{ + ID: p.id, + Seq: p.sequence, + Data: t, + } + + msg := &icmp.Message{ + Type: conn.ICMPRequestType(), + Code: 0, + Body: body, + } + + msgBytes, err = msg.Marshal(nil) + if err != nil { + return err + } + if _, err := conn.WriteTo(msgBytes, dst); err != nil { if neterr, ok := err.(*net.OpError); ok { if neterr.Err == syscall.ENOBUFS { + // Slow down the busy loop + time.Sleep(2 * time.Millisecond) + continue } } + return err } - handler := p.OnSend - if handler != nil { - outPkt := &Packet{ - Nbytes: len(msgBytes), - IPAddr: p.ipaddr, - Addr: p.addr, - Seq: p.sequence, - ID: p.id, - } - handler(outPkt) - } - // mark this sequence as in-flight - p.awaitingSequences[buildLookupKey(p.currentUUID, p.sequence)] = struct{}{} - p.PacketsSent++ - p.sequence++ - if p.sequence > 65535 { - p.currentUUID = uuid.New() - p.sequence = 0 - } + break } + handler := p.OnSend + if handler != nil { + handler(&Packet{ + Nbytes: len(msgBytes), + IPAddr: p.ipaddr, + Addr: p.addr, + Seq: p.sequence, + ID: p.id, + }) + } + + // mark this sequence as in-flight + p.awaitingSequences[buildLookupKey(p.currentUUID, p.sequence)] = sentAt + p.PacketsSent++ + p.sequence++ + if p.sequence > 65535 { + p.currentUUID = uuid.New() + p.sequence = 0 + } + return nil } @@ -770,12 +836,41 @@ func (p *Pinger) listen() (packetConn, error) { return conn, nil } +func (p *Pinger) checkForLostPackets() { + if p.PacketTimeout == 0 { + // Packets shall not time out + return + } + + now := time.Now() + + for k, sentAt := range p.awaitingSequences { + if delta := now.Sub(sentAt); delta >= p.PacketTimeout { + delete(p.awaitingSequences, k) + + p.statsMu.Lock() + p.PacketsLost++ + p.statsMu.Unlock() + + if p.OnLost != nil { + usedUUID, sequenceID, err := parseLookupKey(k) + // This should never happen as all keys used in the map are build using buildLookupKey() + if err != nil { + p.logger.Errorf("invalid lookup key %q: %s", k, err) + } + + p.OnLost(usedUUID, sequenceID, delta) + } + } + } +} + func bytesToTime(b []byte) time.Time { var nsec int64 for i := uint8(0); i < 8; i++ { nsec += int64(b[i]) << ((7 - i) * 8) } - return time.Unix(nsec/1000000000, nsec%1000000000) + return time.Unix(nsec/1_000_000_000, nsec%1_000_000_000) } func isIPv4(ip net.IP) bool { @@ -802,3 +897,24 @@ func getSeed() int64 { func buildLookupKey(id uuid.UUID, sequenceId int) string { return string(id[:]) + strconv.Itoa(sequenceId) } + +// parseLookupKey retries UUID and sequence ID from a lookup key build with buildLookupKey +func parseLookupKey(key string) (uuid.UUID, int, error) { + // 16 bytes for the UUID and at least one byte for the sequence ID + if len(key) < 17 { + return uuid.UUID{}, 0, fmt.Errorf("lookup key to short, expected length to be at least 17 but was %d", len(key)) + } + + // The first 16 bytes represent the UUID + readUUID, err := uuid.FromBytes([]byte(key[:16])) + if err != nil { + return uuid.UUID{}, 0, fmt.Errorf("unmarshalling UUID from lookup key: %w", err) + } + + sequenceID, err := strconv.Atoi(key[16:]) + if err != nil { + return uuid.UUID{}, 0, fmt.Errorf("reading sequence ID from lookup key: %w", err) + } + + return readUUID, sequenceID, nil +} diff --git a/ping_test.go b/ping_test.go index 073aa6c..8708281 100644 --- a/ping_test.go +++ b/ping_test.go @@ -37,7 +37,7 @@ func TestProcessPacket(t *testing.T) { Seq: pinger.sequence, Data: data, } - pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, pinger.sequence)] = struct{}{} + pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, pinger.sequence)] = time.Now() msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply, @@ -458,7 +458,7 @@ func TestStatisticsLossy(t *testing.T) { if stats.PacketsSent != 20 { t.Errorf("Expected %v, got %v", 20, stats.PacketsSent) } - if stats.PacketLoss != 50 { + if stats.PacketLoss != 0 { t.Errorf("Expected %v, got %v", 50, stats.PacketLoss) } if stats.MinRtt != time.Duration(10) { @@ -606,7 +606,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { Data: data, } // register the sequence as sent - pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, 0)] = struct{}{} + pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, 0)] = time.Now() msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply,