refactor: flatten awaitingSequences map

Not nesting this map anymore has the advantages of making the access logic simpler and avoiding memory leaks due to empty maps being referred to by former UUIDs not being removed from the root list.
This commit is contained in:
Florian Loch 2022-02-25 18:07:41 +01:00
parent b89bb75386
commit 1f8d90182d
2 changed files with 46 additions and 64 deletions

68
ping.go
View File

@ -60,6 +60,7 @@ import (
"math" "math"
"math/rand" "math/rand"
"net" "net"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
@ -87,25 +88,22 @@ var (
// New returns a new Pinger struct pointer. // New returns a new Pinger struct pointer.
func New(addr string) *Pinger { func New(addr string) *Pinger {
r := rand.New(rand.NewSource(getSeed())) r := rand.New(rand.NewSource(getSeed()))
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})
return &Pinger{ return &Pinger{
Count: -1, Count: -1,
Interval: time.Second, Interval: time.Second,
RecordRtts: true, RecordRtts: true,
Size: timeSliceLength + trackerLength, Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64), Timeout: time.Duration(math.MaxInt64),
addr: addr, addr: addr,
done: make(chan interface{}), done: make(chan interface{}),
id: r.Intn(math.MaxUint16), id: r.Intn(math.MaxUint16),
trackerUUIDs: []uuid.UUID{firstUUID}, currentUUID: uuid.New(),
ipaddr: nil, ipaddr: nil,
ipv4: false, ipv4: false,
network: "ip", network: "ip",
protocol: "udp", protocol: "udp",
awaitingSequences: firstSequence, awaitingSequences: make(map[string]struct{}),
TTL: 64, TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
} }
@ -189,14 +187,14 @@ type Pinger struct {
ipaddr *net.IPAddr ipaddr *net.IPAddr
addr string addr string
// trackerUUIDs is the list of UUIDs being used for sending packets. // currentUUID is the current UUID used to build unique and recognizable packet payloads
trackerUUIDs []uuid.UUID currentUUID uuid.UUID
ipv4 bool ipv4 bool
id int id int
sequence int sequence int
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
awaitingSequences map[uuid.UUID]map[int]struct{} awaitingSequences map[string]struct{}
// network is one of "ip", "ip4", or "ip6". // network is one of "ip", "ip4", or "ip6".
network string network string
// protocol is "icmp" or "udp". // protocol is "icmp" or "udp".
@ -607,27 +605,6 @@ func (p *Pinger) recvICMP(
} }
} }
// getPacketUUID scans the tracking slice for matches.
func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
var packetUUID uuid.UUID
err := packetUUID.UnmarshalBinary(pkt[timeSliceLength : timeSliceLength+trackerLength])
if err != nil {
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
}
for _, item := range p.trackerUUIDs {
if item == packetUUID {
return &packetUUID, nil
}
}
return nil, nil
}
// getCurrentTrackerUUID grabs the latest tracker UUID.
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
}
func (p *Pinger) processPacket(recv *packet) error { func (p *Pinger) processPacket(recv *packet) error {
receivedAt := time.Now() receivedAt := time.Now()
var proto int var proto int
@ -667,24 +644,29 @@ func (p *Pinger) processPacket(recv *packet) error {
len(pkt.Data), pkt.Data) len(pkt.Data), pkt.Data)
} }
pktUUID, err := p.getPacketUUID(pkt.Data) var pktUUID uuid.UUID
if err != nil || pktUUID == nil { err = pktUUID.UnmarshalBinary(pkt.Data[timeSliceLength : timeSliceLength+trackerLength])
return err if err != nil {
return fmt.Errorf("error decoding tracking UUID: %w", err)
} }
timestamp := bytesToTime(pkt.Data[:timeSliceLength]) timestamp := bytesToTime(pkt.Data[:timeSliceLength])
inPkt.Rtt = receivedAt.Sub(timestamp) inPkt.Rtt = receivedAt.Sub(timestamp)
inPkt.Seq = pkt.Seq inPkt.Seq = pkt.Seq
key := buildLookupKey(pktUUID, pkt.Seq)
// If we've already received this sequence, ignore it. // If we've already received this sequence, ignore it.
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight { if _, inflight := p.awaitingSequences[key]; !inflight {
p.PacketsRecvDuplicates++ p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil { if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(inPkt) p.OnDuplicateRecv(inPkt)
} }
return nil return nil
} }
// remove it from the list of sequences we're waiting for so we don't get duplicates.
delete(p.awaitingSequences[*pktUUID], pkt.Seq) // remove it from the list of sequences we're waiting for, so we don't get duplicates.
delete(p.awaitingSequences, key)
p.updateStatistics(inPkt) p.updateStatistics(inPkt)
default: default:
// Very bad, not sure how this can happen // Very bad, not sure how this can happen
@ -705,8 +687,7 @@ func (p *Pinger) sendICMP(conn packetConn) error {
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone} dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
} }
currentUUID := p.getCurrentTrackerUUID() uuidEncoded, err := p.currentUUID.MarshalBinary()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil { if err != nil {
return fmt.Errorf("unable to marshal UUID binary: %w", err) return fmt.Errorf("unable to marshal UUID binary: %w", err)
} }
@ -753,13 +734,11 @@ func (p *Pinger) sendICMP(conn packetConn) error {
handler(outPkt) handler(outPkt)
} }
// mark this sequence as in-flight // mark this sequence as in-flight
p.awaitingSequences[currentUUID][p.sequence] = struct{}{} p.awaitingSequences[buildLookupKey(p.currentUUID, p.sequence)] = struct{}{}
p.PacketsSent++ p.PacketsSent++
p.sequence++ p.sequence++
if p.sequence > 65535 { if p.sequence > 65535 {
newUUID := uuid.New() p.currentUUID = uuid.New()
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
p.awaitingSequences[newUUID] = make(map[int]struct{})
p.sequence = 0 p.sequence = 0
} }
break break
@ -818,3 +797,8 @@ var seed int64 = time.Now().UnixNano()
func getSeed() int64 { func getSeed() int64 {
return atomic.AddInt64(&seed, 1) return atomic.AddInt64(&seed, 1)
} }
// buildLookupKey builds the key required for lookups on awaitingSequences map
func buildLookupKey(id uuid.UUID, sequenceId int) string {
return string(id[:]) + strconv.Itoa(sequenceId)
}

View File

@ -23,8 +23,7 @@ func TestProcessPacket(t *testing.T) {
shouldBe1++ shouldBe1++
} }
currentUUID := pinger.getCurrentTrackerUUID() uuidEncoded, err := pinger.currentUUID.MarshalBinary()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil { if err != nil {
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
} }
@ -38,7 +37,7 @@ func TestProcessPacket(t *testing.T) {
Seq: pinger.sequence, Seq: pinger.sequence,
Data: data, Data: data,
} }
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{} pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, pinger.sequence)] = struct{}{}
msg := &icmp.Message{ msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply, Type: ipv4.ICMPTypeEchoReply,
@ -67,7 +66,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) {
shouldBe0++ shouldBe0++
} }
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() currentUUID, err := pinger.currentUUID.MarshalBinary()
if err != nil { if err != nil {
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
} }
@ -110,7 +109,7 @@ func TestProcessPacket_IDMismatch(t *testing.T) {
shouldBe0++ shouldBe0++
} }
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() currentUUID, err := pinger.currentUUID.MarshalBinary()
if err != nil { if err != nil {
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
} }
@ -190,7 +189,7 @@ func TestProcessPacket_LargePacket(t *testing.T) {
pinger := makeTestPinger() pinger := makeTestPinger()
pinger.Size = 4096 pinger.Size = 4096
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() currentUUID, err := pinger.currentUUID.MarshalBinary()
if err != nil { if err != nil {
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
} }
@ -543,7 +542,7 @@ func BenchmarkProcessPacket(b *testing.B) {
pinger.protocol = "ip4:icmp" pinger.protocol = "ip4:icmp"
pinger.id = 123 pinger.id = 123
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() currentUUID, err := pinger.currentUUID.MarshalBinary()
if err != nil { if err != nil {
b.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) b.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
} }
@ -573,7 +572,7 @@ func BenchmarkProcessPacket(b *testing.B) {
} }
for k := 0; k < b.N; k++ { for k := 0; k < b.N; k++ {
pinger.processPacket(&pkt) _ = pinger.processPacket(&pkt)
} }
} }
@ -592,8 +591,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
dups++ dups++
} }
currentUUID := pinger.getCurrentTrackerUUID() uuidEncoded, err := pinger.currentUUID.MarshalBinary()
uuidEncoded, err := currentUUID.MarshalBinary()
if err != nil { if err != nil {
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
} }
@ -608,7 +606,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
Data: data, Data: data,
} }
// register the sequence as sent // register the sequence as sent
pinger.awaitingSequences[currentUUID][0] = struct{}{} pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, 0)] = struct{}{}
msg := &icmp.Message{ msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply, Type: ipv4.ICMPTypeEchoReply,
@ -640,14 +638,14 @@ type testPacketConn struct{}
func (c testPacketConn) Close() error { return nil } func (c testPacketConn) Close() error { return nil }
func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho } func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho }
func (c testPacketConn) SetFlagTTL() error { return nil } func (c testPacketConn) SetFlagTTL() error { return nil }
func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil } func (c testPacketConn) SetReadDeadline(_ time.Time) error { return nil }
func (c testPacketConn) SetTTL(t int) {} func (c testPacketConn) SetTTL(_ int) {}
func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { func (c testPacketConn) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, nil, nil return 0, 0, nil, nil
} }
func (c testPacketConn) WriteTo(b []byte, dst net.Addr) (int, error) { func (c testPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) {
return len(b), nil return len(b), nil
} }
@ -655,7 +653,7 @@ type testPacketConnBadWrite struct {
testPacketConn testPacketConn
} }
func (c testPacketConnBadWrite) WriteTo(b []byte, dst net.Addr) (int, error) { func (c testPacketConnBadWrite) WriteTo(_ []byte, _ net.Addr) (int, error) {
return 0, errors.New("bad write") return 0, errors.New("bad write")
} }
@ -684,7 +682,7 @@ type testPacketConnBadRead struct {
testPacketConn testPacketConn
} }
func (c testPacketConnBadRead) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { func (c testPacketConnBadRead) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, nil, errors.New("bad read") return 0, 0, nil, errors.New("bad read")
} }