refactor: flatten awaitingSequences map

Not nesting this map anymore has the advantages of making the access logic simpler and avoiding memory leaks due to empty maps being referred to by former UUIDs not being removed from the root list.
This commit is contained in:
Florian Loch 2022-02-25 18:07:41 +01:00
parent b89bb75386
commit 1f8d90182d
2 changed files with 46 additions and 64 deletions

78
ping.go
View File

@ -60,6 +60,7 @@ import (
"math"
"math/rand"
"net"
"strconv"
"sync"
"sync/atomic"
"syscall"
@ -87,25 +88,22 @@ var (
// New returns a new Pinger struct pointer.
func New(addr string) *Pinger {
r := rand.New(rand.NewSource(getSeed()))
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})
return &Pinger{
Count: -1,
Interval: time.Second,
RecordRtts: true,
Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64),
return &Pinger{
Count: -1,
Interval: time.Second,
RecordRtts: true,
Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64),
addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
trackerUUIDs: []uuid.UUID{firstUUID},
currentUUID: uuid.New(),
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
awaitingSequences: firstSequence,
awaitingSequences: make(map[string]struct{}),
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
@ -189,14 +187,14 @@ type Pinger struct {
ipaddr *net.IPAddr
addr string
// trackerUUIDs is the list of UUIDs being used for sending packets.
trackerUUIDs []uuid.UUID
// currentUUID is the current UUID used to build unique and recognizable packet payloads
currentUUID uuid.UUID
ipv4 bool
id int
sequence int
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
awaitingSequences map[uuid.UUID]map[int]struct{}
awaitingSequences map[string]struct{}
// network is one of "ip", "ip4", or "ip6".
network string
// protocol is "icmp" or "udp".
@ -607,27 +605,6 @@ func (p *Pinger) recvICMP(
}
}
// getPacketUUID scans the tracking slice for matches.
func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
var packetUUID uuid.UUID
err := packetUUID.UnmarshalBinary(pkt[timeSliceLength : timeSliceLength+trackerLength])
if err != nil {
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
}
for _, item := range p.trackerUUIDs {
if item == packetUUID {
return &packetUUID, nil
}
}
return nil, nil
}
// getCurrentTrackerUUID grabs the latest tracker UUID.
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
}
func (p *Pinger) processPacket(recv *packet) error {
receivedAt := time.Now()
var proto int
@ -667,24 +644,29 @@ func (p *Pinger) processPacket(recv *packet) error {
len(pkt.Data), pkt.Data)
}
pktUUID, err := p.getPacketUUID(pkt.Data)
if err != nil || pktUUID == nil {
return err
var pktUUID uuid.UUID
err = pktUUID.UnmarshalBinary(pkt.Data[timeSliceLength : timeSliceLength+trackerLength])
if err != nil {
return fmt.Errorf("error decoding tracking UUID: %w", err)
}
timestamp := bytesToTime(pkt.Data[:timeSliceLength])
inPkt.Rtt = receivedAt.Sub(timestamp)
inPkt.Seq = pkt.Seq
key := buildLookupKey(pktUUID, pkt.Seq)
// If we've already received this sequence, ignore it.
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
if _, inflight := p.awaitingSequences[key]; !inflight {
p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(inPkt)
}
return nil
}
// remove it from the list of sequences we're waiting for so we don't get duplicates.
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
// remove it from the list of sequences we're waiting for, so we don't get duplicates.
delete(p.awaitingSequences, key)
p.updateStatistics(inPkt)
default:
// Very bad, not sure how this can happen
@ -705,8 +687,7 @@ func (p *Pinger) sendICMP(conn packetConn) error {
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
}
currentUUID := p.getCurrentTrackerUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
uuidEncoded, err := p.currentUUID.MarshalBinary()
if err != nil {
return fmt.Errorf("unable to marshal UUID binary: %w", err)
}
@ -753,13 +734,11 @@ func (p *Pinger) sendICMP(conn packetConn) error {
handler(outPkt)
}
// mark this sequence as in-flight
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
p.awaitingSequences[buildLookupKey(p.currentUUID, p.sequence)] = struct{}{}
p.PacketsSent++
p.sequence++
if p.sequence > 65535 {
newUUID := uuid.New()
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
p.awaitingSequences[newUUID] = make(map[int]struct{})
p.currentUUID = uuid.New()
p.sequence = 0
}
break
@ -818,3 +797,8 @@ var seed int64 = time.Now().UnixNano()
func getSeed() int64 {
return atomic.AddInt64(&seed, 1)
}
// buildLookupKey builds the key required for lookups on awaitingSequences map
func buildLookupKey(id uuid.UUID, sequenceId int) string {
return string(id[:]) + strconv.Itoa(sequenceId)
}

View File

@ -23,8 +23,7 @@ func TestProcessPacket(t *testing.T) {
shouldBe1++
}
currentUUID := pinger.getCurrentTrackerUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
uuidEncoded, err := pinger.currentUUID.MarshalBinary()
if err != nil {
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
}
@ -38,7 +37,7 @@ func TestProcessPacket(t *testing.T) {
Seq: pinger.sequence,
Data: data,
}
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}
pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, pinger.sequence)] = struct{}{}
msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
@ -67,7 +66,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) {
shouldBe0++
}
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.currentUUID.MarshalBinary()
if err != nil {
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
}
@ -110,7 +109,7 @@ func TestProcessPacket_IDMismatch(t *testing.T) {
shouldBe0++
}
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.currentUUID.MarshalBinary()
if err != nil {
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
}
@ -190,7 +189,7 @@ func TestProcessPacket_LargePacket(t *testing.T) {
pinger := makeTestPinger()
pinger.Size = 4096
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.currentUUID.MarshalBinary()
if err != nil {
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
}
@ -543,7 +542,7 @@ func BenchmarkProcessPacket(b *testing.B) {
pinger.protocol = "ip4:icmp"
pinger.id = 123
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
currentUUID, err := pinger.currentUUID.MarshalBinary()
if err != nil {
b.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
}
@ -573,7 +572,7 @@ func BenchmarkProcessPacket(b *testing.B) {
}
for k := 0; k < b.N; k++ {
pinger.processPacket(&pkt)
_ = pinger.processPacket(&pkt)
}
}
@ -592,8 +591,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
dups++
}
currentUUID := pinger.getCurrentTrackerUUID()
uuidEncoded, err := currentUUID.MarshalBinary()
uuidEncoded, err := pinger.currentUUID.MarshalBinary()
if err != nil {
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
}
@ -608,7 +606,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
Data: data,
}
// register the sequence as sent
pinger.awaitingSequences[currentUUID][0] = struct{}{}
pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, 0)] = struct{}{}
msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
@ -640,14 +638,14 @@ type testPacketConn struct{}
func (c testPacketConn) Close() error { return nil }
func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho }
func (c testPacketConn) SetFlagTTL() error { return nil }
func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil }
func (c testPacketConn) SetTTL(t int) {}
func (c testPacketConn) SetReadDeadline(_ time.Time) error { return nil }
func (c testPacketConn) SetTTL(_ int) {}
func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
func (c testPacketConn) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, nil, nil
}
func (c testPacketConn) WriteTo(b []byte, dst net.Addr) (int, error) {
func (c testPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) {
return len(b), nil
}
@ -655,7 +653,7 @@ type testPacketConnBadWrite struct {
testPacketConn
}
func (c testPacketConnBadWrite) WriteTo(b []byte, dst net.Addr) (int, error) {
func (c testPacketConnBadWrite) WriteTo(_ []byte, _ net.Addr) (int, error) {
return 0, errors.New("bad write")
}
@ -684,7 +682,7 @@ type testPacketConnBadRead struct {
testPacketConn
}
func (c testPacketConnBadRead) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
func (c testPacketConnBadRead) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, nil, errors.New("bad read")
}