From 1f8d90182d63b0b91cb5aac923a80f6aeeaf6ca7 Mon Sep 17 00:00:00 2001 From: Florian Loch Date: Fri, 25 Feb 2022 18:07:41 +0100 Subject: [PATCH] refactor: flatten awaitingSequences map Not nesting this map anymore has the advantages of making the access logic simpler and avoiding memory leaks due to empty maps being referred to by former UUIDs not being removed from the root list. --- ping.go | 78 +++++++++++++++++++++------------------------------- ping_test.go | 32 ++++++++++----------- 2 files changed, 46 insertions(+), 64 deletions(-) diff --git a/ping.go b/ping.go index ef277ab..2854e0a 100644 --- a/ping.go +++ b/ping.go @@ -60,6 +60,7 @@ import ( "math" "math/rand" "net" + "strconv" "sync" "sync/atomic" "syscall" @@ -87,25 +88,22 @@ var ( // New returns a new Pinger struct pointer. func New(addr string) *Pinger { r := rand.New(rand.NewSource(getSeed())) - firstUUID := uuid.New() - var firstSequence = map[uuid.UUID]map[int]struct{}{} - firstSequence[firstUUID] = make(map[int]struct{}) - return &Pinger{ - Count: -1, - Interval: time.Second, - RecordRtts: true, - Size: timeSliceLength + trackerLength, - Timeout: time.Duration(math.MaxInt64), + return &Pinger{ + Count: -1, + Interval: time.Second, + RecordRtts: true, + Size: timeSliceLength + trackerLength, + Timeout: time.Duration(math.MaxInt64), addr: addr, done: make(chan interface{}), id: r.Intn(math.MaxUint16), - trackerUUIDs: []uuid.UUID{firstUUID}, + currentUUID: uuid.New(), ipaddr: nil, ipv4: false, network: "ip", protocol: "udp", - awaitingSequences: firstSequence, + awaitingSequences: make(map[string]struct{}), TTL: 64, logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, } @@ -189,14 +187,14 @@ type Pinger struct { ipaddr *net.IPAddr addr string - // trackerUUIDs is the list of UUIDs being used for sending packets. - trackerUUIDs []uuid.UUID + // currentUUID is the current UUID used to build unique and recognizable packet payloads + currentUUID uuid.UUID ipv4 bool id int sequence int // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts - awaitingSequences map[uuid.UUID]map[int]struct{} + awaitingSequences map[string]struct{} // network is one of "ip", "ip4", or "ip6". network string // protocol is "icmp" or "udp". @@ -607,27 +605,6 @@ func (p *Pinger) recvICMP( } } -// getPacketUUID scans the tracking slice for matches. -func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) { - var packetUUID uuid.UUID - err := packetUUID.UnmarshalBinary(pkt[timeSliceLength : timeSliceLength+trackerLength]) - if err != nil { - return nil, fmt.Errorf("error decoding tracking UUID: %w", err) - } - - for _, item := range p.trackerUUIDs { - if item == packetUUID { - return &packetUUID, nil - } - } - return nil, nil -} - -// getCurrentTrackerUUID grabs the latest tracker UUID. -func (p *Pinger) getCurrentTrackerUUID() uuid.UUID { - return p.trackerUUIDs[len(p.trackerUUIDs)-1] -} - func (p *Pinger) processPacket(recv *packet) error { receivedAt := time.Now() var proto int @@ -667,24 +644,29 @@ func (p *Pinger) processPacket(recv *packet) error { len(pkt.Data), pkt.Data) } - pktUUID, err := p.getPacketUUID(pkt.Data) - if err != nil || pktUUID == nil { - return err + var pktUUID uuid.UUID + err = pktUUID.UnmarshalBinary(pkt.Data[timeSliceLength : timeSliceLength+trackerLength]) + if err != nil { + return fmt.Errorf("error decoding tracking UUID: %w", err) } timestamp := bytesToTime(pkt.Data[:timeSliceLength]) inPkt.Rtt = receivedAt.Sub(timestamp) inPkt.Seq = pkt.Seq + + key := buildLookupKey(pktUUID, pkt.Seq) + // If we've already received this sequence, ignore it. - if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight { + if _, inflight := p.awaitingSequences[key]; !inflight { p.PacketsRecvDuplicates++ if p.OnDuplicateRecv != nil { p.OnDuplicateRecv(inPkt) } return nil } - // remove it from the list of sequences we're waiting for so we don't get duplicates. - delete(p.awaitingSequences[*pktUUID], pkt.Seq) + + // remove it from the list of sequences we're waiting for, so we don't get duplicates. + delete(p.awaitingSequences, key) p.updateStatistics(inPkt) default: // Very bad, not sure how this can happen @@ -705,8 +687,7 @@ func (p *Pinger) sendICMP(conn packetConn) error { dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone} } - currentUUID := p.getCurrentTrackerUUID() - uuidEncoded, err := currentUUID.MarshalBinary() + uuidEncoded, err := p.currentUUID.MarshalBinary() if err != nil { return fmt.Errorf("unable to marshal UUID binary: %w", err) } @@ -753,13 +734,11 @@ func (p *Pinger) sendICMP(conn packetConn) error { handler(outPkt) } // mark this sequence as in-flight - p.awaitingSequences[currentUUID][p.sequence] = struct{}{} + p.awaitingSequences[buildLookupKey(p.currentUUID, p.sequence)] = struct{}{} p.PacketsSent++ p.sequence++ if p.sequence > 65535 { - newUUID := uuid.New() - p.trackerUUIDs = append(p.trackerUUIDs, newUUID) - p.awaitingSequences[newUUID] = make(map[int]struct{}) + p.currentUUID = uuid.New() p.sequence = 0 } break @@ -818,3 +797,8 @@ var seed int64 = time.Now().UnixNano() func getSeed() int64 { return atomic.AddInt64(&seed, 1) } + +// buildLookupKey builds the key required for lookups on awaitingSequences map +func buildLookupKey(id uuid.UUID, sequenceId int) string { + return string(id[:]) + strconv.Itoa(sequenceId) +} diff --git a/ping_test.go b/ping_test.go index b8755e7..073aa6c 100644 --- a/ping_test.go +++ b/ping_test.go @@ -23,8 +23,7 @@ func TestProcessPacket(t *testing.T) { shouldBe1++ } - currentUUID := pinger.getCurrentTrackerUUID() - uuidEncoded, err := currentUUID.MarshalBinary() + uuidEncoded, err := pinger.currentUUID.MarshalBinary() if err != nil { t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -38,7 +37,7 @@ func TestProcessPacket(t *testing.T) { Seq: pinger.sequence, Data: data, } - pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{} + pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, pinger.sequence)] = struct{}{} msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply, @@ -67,7 +66,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) { shouldBe0++ } - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.currentUUID.MarshalBinary() if err != nil { t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -110,7 +109,7 @@ func TestProcessPacket_IDMismatch(t *testing.T) { shouldBe0++ } - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.currentUUID.MarshalBinary() if err != nil { t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -190,7 +189,7 @@ func TestProcessPacket_LargePacket(t *testing.T) { pinger := makeTestPinger() pinger.Size = 4096 - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.currentUUID.MarshalBinary() if err != nil { t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -543,7 +542,7 @@ func BenchmarkProcessPacket(b *testing.B) { pinger.protocol = "ip4:icmp" pinger.id = 123 - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.currentUUID.MarshalBinary() if err != nil { b.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -573,7 +572,7 @@ func BenchmarkProcessPacket(b *testing.B) { } for k := 0; k < b.N; k++ { - pinger.processPacket(&pkt) + _ = pinger.processPacket(&pkt) } } @@ -592,8 +591,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { dups++ } - currentUUID := pinger.getCurrentTrackerUUID() - uuidEncoded, err := currentUUID.MarshalBinary() + uuidEncoded, err := pinger.currentUUID.MarshalBinary() if err != nil { t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -608,7 +606,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { Data: data, } // register the sequence as sent - pinger.awaitingSequences[currentUUID][0] = struct{}{} + pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, 0)] = struct{}{} msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply, @@ -640,14 +638,14 @@ type testPacketConn struct{} func (c testPacketConn) Close() error { return nil } func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho } func (c testPacketConn) SetFlagTTL() error { return nil } -func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil } -func (c testPacketConn) SetTTL(t int) {} +func (c testPacketConn) SetReadDeadline(_ time.Time) error { return nil } +func (c testPacketConn) SetTTL(_ int) {} -func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { +func (c testPacketConn) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) { return 0, 0, nil, nil } -func (c testPacketConn) WriteTo(b []byte, dst net.Addr) (int, error) { +func (c testPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) { return len(b), nil } @@ -655,7 +653,7 @@ type testPacketConnBadWrite struct { testPacketConn } -func (c testPacketConnBadWrite) WriteTo(b []byte, dst net.Addr) (int, error) { +func (c testPacketConnBadWrite) WriteTo(_ []byte, _ net.Addr) (int, error) { return 0, errors.New("bad write") } @@ -684,7 +682,7 @@ type testPacketConnBadRead struct { testPacketConn } -func (c testPacketConnBadRead) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { +func (c testPacketConnBadRead) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) { return 0, 0, nil, errors.New("bad read") }