mirror of
https://github.com/go-ping/ping.git
synced 2025-07-04 17:56:18 +00:00
refactor: flatten awaitingSequences map
Not nesting this map anymore has the advantages of making the access logic simpler and avoiding memory leaks due to empty maps being referred to by former UUIDs not being removed from the root list.
This commit is contained in:
parent
b89bb75386
commit
1f8d90182d
68
ping.go
68
ping.go
@ -60,6 +60,7 @@ import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
@ -87,25 +88,22 @@ var (
|
||||
// New returns a new Pinger struct pointer.
|
||||
func New(addr string) *Pinger {
|
||||
r := rand.New(rand.NewSource(getSeed()))
|
||||
firstUUID := uuid.New()
|
||||
var firstSequence = map[uuid.UUID]map[int]struct{}{}
|
||||
firstSequence[firstUUID] = make(map[int]struct{})
|
||||
|
||||
return &Pinger{
|
||||
Count: -1,
|
||||
Interval: time.Second,
|
||||
RecordRtts: true,
|
||||
Size: timeSliceLength + trackerLength,
|
||||
Timeout: time.Duration(math.MaxInt64),
|
||||
|
||||
addr: addr,
|
||||
done: make(chan interface{}),
|
||||
id: r.Intn(math.MaxUint16),
|
||||
trackerUUIDs: []uuid.UUID{firstUUID},
|
||||
currentUUID: uuid.New(),
|
||||
ipaddr: nil,
|
||||
ipv4: false,
|
||||
network: "ip",
|
||||
protocol: "udp",
|
||||
awaitingSequences: firstSequence,
|
||||
awaitingSequences: make(map[string]struct{}),
|
||||
TTL: 64,
|
||||
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
|
||||
}
|
||||
@ -189,14 +187,14 @@ type Pinger struct {
|
||||
ipaddr *net.IPAddr
|
||||
addr string
|
||||
|
||||
// trackerUUIDs is the list of UUIDs being used for sending packets.
|
||||
trackerUUIDs []uuid.UUID
|
||||
// currentUUID is the current UUID used to build unique and recognizable packet payloads
|
||||
currentUUID uuid.UUID
|
||||
|
||||
ipv4 bool
|
||||
id int
|
||||
sequence int
|
||||
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
|
||||
awaitingSequences map[uuid.UUID]map[int]struct{}
|
||||
awaitingSequences map[string]struct{}
|
||||
// network is one of "ip", "ip4", or "ip6".
|
||||
network string
|
||||
// protocol is "icmp" or "udp".
|
||||
@ -607,27 +605,6 @@ func (p *Pinger) recvICMP(
|
||||
}
|
||||
}
|
||||
|
||||
// getPacketUUID scans the tracking slice for matches.
|
||||
func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
|
||||
var packetUUID uuid.UUID
|
||||
err := packetUUID.UnmarshalBinary(pkt[timeSliceLength : timeSliceLength+trackerLength])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
|
||||
}
|
||||
|
||||
for _, item := range p.trackerUUIDs {
|
||||
if item == packetUUID {
|
||||
return &packetUUID, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// getCurrentTrackerUUID grabs the latest tracker UUID.
|
||||
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
|
||||
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
|
||||
}
|
||||
|
||||
func (p *Pinger) processPacket(recv *packet) error {
|
||||
receivedAt := time.Now()
|
||||
var proto int
|
||||
@ -667,24 +644,29 @@ func (p *Pinger) processPacket(recv *packet) error {
|
||||
len(pkt.Data), pkt.Data)
|
||||
}
|
||||
|
||||
pktUUID, err := p.getPacketUUID(pkt.Data)
|
||||
if err != nil || pktUUID == nil {
|
||||
return err
|
||||
var pktUUID uuid.UUID
|
||||
err = pktUUID.UnmarshalBinary(pkt.Data[timeSliceLength : timeSliceLength+trackerLength])
|
||||
if err != nil {
|
||||
return fmt.Errorf("error decoding tracking UUID: %w", err)
|
||||
}
|
||||
|
||||
timestamp := bytesToTime(pkt.Data[:timeSliceLength])
|
||||
inPkt.Rtt = receivedAt.Sub(timestamp)
|
||||
inPkt.Seq = pkt.Seq
|
||||
|
||||
key := buildLookupKey(pktUUID, pkt.Seq)
|
||||
|
||||
// If we've already received this sequence, ignore it.
|
||||
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
|
||||
if _, inflight := p.awaitingSequences[key]; !inflight {
|
||||
p.PacketsRecvDuplicates++
|
||||
if p.OnDuplicateRecv != nil {
|
||||
p.OnDuplicateRecv(inPkt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// remove it from the list of sequences we're waiting for so we don't get duplicates.
|
||||
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
|
||||
|
||||
// remove it from the list of sequences we're waiting for, so we don't get duplicates.
|
||||
delete(p.awaitingSequences, key)
|
||||
p.updateStatistics(inPkt)
|
||||
default:
|
||||
// Very bad, not sure how this can happen
|
||||
@ -705,8 +687,7 @@ func (p *Pinger) sendICMP(conn packetConn) error {
|
||||
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
|
||||
}
|
||||
|
||||
currentUUID := p.getCurrentTrackerUUID()
|
||||
uuidEncoded, err := currentUUID.MarshalBinary()
|
||||
uuidEncoded, err := p.currentUUID.MarshalBinary()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to marshal UUID binary: %w", err)
|
||||
}
|
||||
@ -753,13 +734,11 @@ func (p *Pinger) sendICMP(conn packetConn) error {
|
||||
handler(outPkt)
|
||||
}
|
||||
// mark this sequence as in-flight
|
||||
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
|
||||
p.awaitingSequences[buildLookupKey(p.currentUUID, p.sequence)] = struct{}{}
|
||||
p.PacketsSent++
|
||||
p.sequence++
|
||||
if p.sequence > 65535 {
|
||||
newUUID := uuid.New()
|
||||
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
|
||||
p.awaitingSequences[newUUID] = make(map[int]struct{})
|
||||
p.currentUUID = uuid.New()
|
||||
p.sequence = 0
|
||||
}
|
||||
break
|
||||
@ -818,3 +797,8 @@ var seed int64 = time.Now().UnixNano()
|
||||
func getSeed() int64 {
|
||||
return atomic.AddInt64(&seed, 1)
|
||||
}
|
||||
|
||||
// buildLookupKey builds the key required for lookups on awaitingSequences map
|
||||
func buildLookupKey(id uuid.UUID, sequenceId int) string {
|
||||
return string(id[:]) + strconv.Itoa(sequenceId)
|
||||
}
|
||||
|
32
ping_test.go
32
ping_test.go
@ -23,8 +23,7 @@ func TestProcessPacket(t *testing.T) {
|
||||
shouldBe1++
|
||||
}
|
||||
|
||||
currentUUID := pinger.getCurrentTrackerUUID()
|
||||
uuidEncoded, err := currentUUID.MarshalBinary()
|
||||
uuidEncoded, err := pinger.currentUUID.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||
}
|
||||
@ -38,7 +37,7 @@ func TestProcessPacket(t *testing.T) {
|
||||
Seq: pinger.sequence,
|
||||
Data: data,
|
||||
}
|
||||
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}
|
||||
pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, pinger.sequence)] = struct{}{}
|
||||
|
||||
msg := &icmp.Message{
|
||||
Type: ipv4.ICMPTypeEchoReply,
|
||||
@ -67,7 +66,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) {
|
||||
shouldBe0++
|
||||
}
|
||||
|
||||
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
|
||||
currentUUID, err := pinger.currentUUID.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||
}
|
||||
@ -110,7 +109,7 @@ func TestProcessPacket_IDMismatch(t *testing.T) {
|
||||
shouldBe0++
|
||||
}
|
||||
|
||||
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
|
||||
currentUUID, err := pinger.currentUUID.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||
}
|
||||
@ -190,7 +189,7 @@ func TestProcessPacket_LargePacket(t *testing.T) {
|
||||
pinger := makeTestPinger()
|
||||
pinger.Size = 4096
|
||||
|
||||
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
|
||||
currentUUID, err := pinger.currentUUID.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||
}
|
||||
@ -543,7 +542,7 @@ func BenchmarkProcessPacket(b *testing.B) {
|
||||
pinger.protocol = "ip4:icmp"
|
||||
pinger.id = 123
|
||||
|
||||
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
|
||||
currentUUID, err := pinger.currentUUID.MarshalBinary()
|
||||
if err != nil {
|
||||
b.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||
}
|
||||
@ -573,7 +572,7 @@ func BenchmarkProcessPacket(b *testing.B) {
|
||||
}
|
||||
|
||||
for k := 0; k < b.N; k++ {
|
||||
pinger.processPacket(&pkt)
|
||||
_ = pinger.processPacket(&pkt)
|
||||
}
|
||||
}
|
||||
|
||||
@ -592,8 +591,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
|
||||
dups++
|
||||
}
|
||||
|
||||
currentUUID := pinger.getCurrentTrackerUUID()
|
||||
uuidEncoded, err := currentUUID.MarshalBinary()
|
||||
uuidEncoded, err := pinger.currentUUID.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||
}
|
||||
@ -608,7 +606,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
|
||||
Data: data,
|
||||
}
|
||||
// register the sequence as sent
|
||||
pinger.awaitingSequences[currentUUID][0] = struct{}{}
|
||||
pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, 0)] = struct{}{}
|
||||
|
||||
msg := &icmp.Message{
|
||||
Type: ipv4.ICMPTypeEchoReply,
|
||||
@ -640,14 +638,14 @@ type testPacketConn struct{}
|
||||
func (c testPacketConn) Close() error { return nil }
|
||||
func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho }
|
||||
func (c testPacketConn) SetFlagTTL() error { return nil }
|
||||
func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c testPacketConn) SetTTL(t int) {}
|
||||
func (c testPacketConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (c testPacketConn) SetTTL(_ int) {}
|
||||
|
||||
func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
|
||||
func (c testPacketConn) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) {
|
||||
return 0, 0, nil, nil
|
||||
}
|
||||
|
||||
func (c testPacketConn) WriteTo(b []byte, dst net.Addr) (int, error) {
|
||||
func (c testPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
@ -655,7 +653,7 @@ type testPacketConnBadWrite struct {
|
||||
testPacketConn
|
||||
}
|
||||
|
||||
func (c testPacketConnBadWrite) WriteTo(b []byte, dst net.Addr) (int, error) {
|
||||
func (c testPacketConnBadWrite) WriteTo(_ []byte, _ net.Addr) (int, error) {
|
||||
return 0, errors.New("bad write")
|
||||
}
|
||||
|
||||
@ -684,7 +682,7 @@ type testPacketConnBadRead struct {
|
||||
testPacketConn
|
||||
}
|
||||
|
||||
func (c testPacketConnBadRead) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
|
||||
func (c testPacketConnBadRead) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) {
|
||||
return 0, 0, nil, errors.New("bad read")
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user