mirror of
https://github.com/go-ping/ping.git
synced 2025-07-05 02:06:17 +00:00
refactor: flatten awaitingSequences map
Not nesting this map anymore has the advantages of making the access logic simpler and avoiding memory leaks due to empty maps being referred to by former UUIDs not being removed from the root list.
This commit is contained in:
parent
b89bb75386
commit
1f8d90182d
68
ping.go
68
ping.go
@ -60,6 +60,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
@ -87,25 +88,22 @@ var (
|
|||||||
// New returns a new Pinger struct pointer.
|
// New returns a new Pinger struct pointer.
|
||||||
func New(addr string) *Pinger {
|
func New(addr string) *Pinger {
|
||||||
r := rand.New(rand.NewSource(getSeed()))
|
r := rand.New(rand.NewSource(getSeed()))
|
||||||
firstUUID := uuid.New()
|
|
||||||
var firstSequence = map[uuid.UUID]map[int]struct{}{}
|
|
||||||
firstSequence[firstUUID] = make(map[int]struct{})
|
|
||||||
return &Pinger{
|
return &Pinger{
|
||||||
Count: -1,
|
Count: -1,
|
||||||
Interval: time.Second,
|
Interval: time.Second,
|
||||||
RecordRtts: true,
|
RecordRtts: true,
|
||||||
Size: timeSliceLength + trackerLength,
|
Size: timeSliceLength + trackerLength,
|
||||||
Timeout: time.Duration(math.MaxInt64),
|
Timeout: time.Duration(math.MaxInt64),
|
||||||
|
|
||||||
addr: addr,
|
addr: addr,
|
||||||
done: make(chan interface{}),
|
done: make(chan interface{}),
|
||||||
id: r.Intn(math.MaxUint16),
|
id: r.Intn(math.MaxUint16),
|
||||||
trackerUUIDs: []uuid.UUID{firstUUID},
|
currentUUID: uuid.New(),
|
||||||
ipaddr: nil,
|
ipaddr: nil,
|
||||||
ipv4: false,
|
ipv4: false,
|
||||||
network: "ip",
|
network: "ip",
|
||||||
protocol: "udp",
|
protocol: "udp",
|
||||||
awaitingSequences: firstSequence,
|
awaitingSequences: make(map[string]struct{}),
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
|
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
|
||||||
}
|
}
|
||||||
@ -189,14 +187,14 @@ type Pinger struct {
|
|||||||
ipaddr *net.IPAddr
|
ipaddr *net.IPAddr
|
||||||
addr string
|
addr string
|
||||||
|
|
||||||
// trackerUUIDs is the list of UUIDs being used for sending packets.
|
// currentUUID is the current UUID used to build unique and recognizable packet payloads
|
||||||
trackerUUIDs []uuid.UUID
|
currentUUID uuid.UUID
|
||||||
|
|
||||||
ipv4 bool
|
ipv4 bool
|
||||||
id int
|
id int
|
||||||
sequence int
|
sequence int
|
||||||
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
|
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
|
||||||
awaitingSequences map[uuid.UUID]map[int]struct{}
|
awaitingSequences map[string]struct{}
|
||||||
// network is one of "ip", "ip4", or "ip6".
|
// network is one of "ip", "ip4", or "ip6".
|
||||||
network string
|
network string
|
||||||
// protocol is "icmp" or "udp".
|
// protocol is "icmp" or "udp".
|
||||||
@ -607,27 +605,6 @@ func (p *Pinger) recvICMP(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPacketUUID scans the tracking slice for matches.
|
|
||||||
func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
|
|
||||||
var packetUUID uuid.UUID
|
|
||||||
err := packetUUID.UnmarshalBinary(pkt[timeSliceLength : timeSliceLength+trackerLength])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, item := range p.trackerUUIDs {
|
|
||||||
if item == packetUUID {
|
|
||||||
return &packetUUID, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getCurrentTrackerUUID grabs the latest tracker UUID.
|
|
||||||
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
|
|
||||||
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pinger) processPacket(recv *packet) error {
|
func (p *Pinger) processPacket(recv *packet) error {
|
||||||
receivedAt := time.Now()
|
receivedAt := time.Now()
|
||||||
var proto int
|
var proto int
|
||||||
@ -667,24 +644,29 @@ func (p *Pinger) processPacket(recv *packet) error {
|
|||||||
len(pkt.Data), pkt.Data)
|
len(pkt.Data), pkt.Data)
|
||||||
}
|
}
|
||||||
|
|
||||||
pktUUID, err := p.getPacketUUID(pkt.Data)
|
var pktUUID uuid.UUID
|
||||||
if err != nil || pktUUID == nil {
|
err = pktUUID.UnmarshalBinary(pkt.Data[timeSliceLength : timeSliceLength+trackerLength])
|
||||||
return err
|
if err != nil {
|
||||||
|
return fmt.Errorf("error decoding tracking UUID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
timestamp := bytesToTime(pkt.Data[:timeSliceLength])
|
timestamp := bytesToTime(pkt.Data[:timeSliceLength])
|
||||||
inPkt.Rtt = receivedAt.Sub(timestamp)
|
inPkt.Rtt = receivedAt.Sub(timestamp)
|
||||||
inPkt.Seq = pkt.Seq
|
inPkt.Seq = pkt.Seq
|
||||||
|
|
||||||
|
key := buildLookupKey(pktUUID, pkt.Seq)
|
||||||
|
|
||||||
// If we've already received this sequence, ignore it.
|
// If we've already received this sequence, ignore it.
|
||||||
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
|
if _, inflight := p.awaitingSequences[key]; !inflight {
|
||||||
p.PacketsRecvDuplicates++
|
p.PacketsRecvDuplicates++
|
||||||
if p.OnDuplicateRecv != nil {
|
if p.OnDuplicateRecv != nil {
|
||||||
p.OnDuplicateRecv(inPkt)
|
p.OnDuplicateRecv(inPkt)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// remove it from the list of sequences we're waiting for so we don't get duplicates.
|
|
||||||
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
|
// remove it from the list of sequences we're waiting for, so we don't get duplicates.
|
||||||
|
delete(p.awaitingSequences, key)
|
||||||
p.updateStatistics(inPkt)
|
p.updateStatistics(inPkt)
|
||||||
default:
|
default:
|
||||||
// Very bad, not sure how this can happen
|
// Very bad, not sure how this can happen
|
||||||
@ -705,8 +687,7 @@ func (p *Pinger) sendICMP(conn packetConn) error {
|
|||||||
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
|
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
|
||||||
}
|
}
|
||||||
|
|
||||||
currentUUID := p.getCurrentTrackerUUID()
|
uuidEncoded, err := p.currentUUID.MarshalBinary()
|
||||||
uuidEncoded, err := currentUUID.MarshalBinary()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to marshal UUID binary: %w", err)
|
return fmt.Errorf("unable to marshal UUID binary: %w", err)
|
||||||
}
|
}
|
||||||
@ -753,13 +734,11 @@ func (p *Pinger) sendICMP(conn packetConn) error {
|
|||||||
handler(outPkt)
|
handler(outPkt)
|
||||||
}
|
}
|
||||||
// mark this sequence as in-flight
|
// mark this sequence as in-flight
|
||||||
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
|
p.awaitingSequences[buildLookupKey(p.currentUUID, p.sequence)] = struct{}{}
|
||||||
p.PacketsSent++
|
p.PacketsSent++
|
||||||
p.sequence++
|
p.sequence++
|
||||||
if p.sequence > 65535 {
|
if p.sequence > 65535 {
|
||||||
newUUID := uuid.New()
|
p.currentUUID = uuid.New()
|
||||||
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
|
|
||||||
p.awaitingSequences[newUUID] = make(map[int]struct{})
|
|
||||||
p.sequence = 0
|
p.sequence = 0
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
@ -818,3 +797,8 @@ var seed int64 = time.Now().UnixNano()
|
|||||||
func getSeed() int64 {
|
func getSeed() int64 {
|
||||||
return atomic.AddInt64(&seed, 1)
|
return atomic.AddInt64(&seed, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildLookupKey builds the key required for lookups on awaitingSequences map
|
||||||
|
func buildLookupKey(id uuid.UUID, sequenceId int) string {
|
||||||
|
return string(id[:]) + strconv.Itoa(sequenceId)
|
||||||
|
}
|
||||||
|
32
ping_test.go
32
ping_test.go
@ -23,8 +23,7 @@ func TestProcessPacket(t *testing.T) {
|
|||||||
shouldBe1++
|
shouldBe1++
|
||||||
}
|
}
|
||||||
|
|
||||||
currentUUID := pinger.getCurrentTrackerUUID()
|
uuidEncoded, err := pinger.currentUUID.MarshalBinary()
|
||||||
uuidEncoded, err := currentUUID.MarshalBinary()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||||
}
|
}
|
||||||
@ -38,7 +37,7 @@ func TestProcessPacket(t *testing.T) {
|
|||||||
Seq: pinger.sequence,
|
Seq: pinger.sequence,
|
||||||
Data: data,
|
Data: data,
|
||||||
}
|
}
|
||||||
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}
|
pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, pinger.sequence)] = struct{}{}
|
||||||
|
|
||||||
msg := &icmp.Message{
|
msg := &icmp.Message{
|
||||||
Type: ipv4.ICMPTypeEchoReply,
|
Type: ipv4.ICMPTypeEchoReply,
|
||||||
@ -67,7 +66,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) {
|
|||||||
shouldBe0++
|
shouldBe0++
|
||||||
}
|
}
|
||||||
|
|
||||||
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
|
currentUUID, err := pinger.currentUUID.MarshalBinary()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||||
}
|
}
|
||||||
@ -110,7 +109,7 @@ func TestProcessPacket_IDMismatch(t *testing.T) {
|
|||||||
shouldBe0++
|
shouldBe0++
|
||||||
}
|
}
|
||||||
|
|
||||||
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
|
currentUUID, err := pinger.currentUUID.MarshalBinary()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||||
}
|
}
|
||||||
@ -190,7 +189,7 @@ func TestProcessPacket_LargePacket(t *testing.T) {
|
|||||||
pinger := makeTestPinger()
|
pinger := makeTestPinger()
|
||||||
pinger.Size = 4096
|
pinger.Size = 4096
|
||||||
|
|
||||||
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
|
currentUUID, err := pinger.currentUUID.MarshalBinary()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||||
}
|
}
|
||||||
@ -543,7 +542,7 @@ func BenchmarkProcessPacket(b *testing.B) {
|
|||||||
pinger.protocol = "ip4:icmp"
|
pinger.protocol = "ip4:icmp"
|
||||||
pinger.id = 123
|
pinger.id = 123
|
||||||
|
|
||||||
currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary()
|
currentUUID, err := pinger.currentUUID.MarshalBinary()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
b.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||||
}
|
}
|
||||||
@ -573,7 +572,7 @@ func BenchmarkProcessPacket(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for k := 0; k < b.N; k++ {
|
for k := 0; k < b.N; k++ {
|
||||||
pinger.processPacket(&pkt)
|
_ = pinger.processPacket(&pkt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -592,8 +591,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
|
|||||||
dups++
|
dups++
|
||||||
}
|
}
|
||||||
|
|
||||||
currentUUID := pinger.getCurrentTrackerUUID()
|
uuidEncoded, err := pinger.currentUUID.MarshalBinary()
|
||||||
uuidEncoded, err := currentUUID.MarshalBinary()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err))
|
||||||
}
|
}
|
||||||
@ -608,7 +606,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
|
|||||||
Data: data,
|
Data: data,
|
||||||
}
|
}
|
||||||
// register the sequence as sent
|
// register the sequence as sent
|
||||||
pinger.awaitingSequences[currentUUID][0] = struct{}{}
|
pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, 0)] = struct{}{}
|
||||||
|
|
||||||
msg := &icmp.Message{
|
msg := &icmp.Message{
|
||||||
Type: ipv4.ICMPTypeEchoReply,
|
Type: ipv4.ICMPTypeEchoReply,
|
||||||
@ -640,14 +638,14 @@ type testPacketConn struct{}
|
|||||||
func (c testPacketConn) Close() error { return nil }
|
func (c testPacketConn) Close() error { return nil }
|
||||||
func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho }
|
func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho }
|
||||||
func (c testPacketConn) SetFlagTTL() error { return nil }
|
func (c testPacketConn) SetFlagTTL() error { return nil }
|
||||||
func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil }
|
func (c testPacketConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||||
func (c testPacketConn) SetTTL(t int) {}
|
func (c testPacketConn) SetTTL(_ int) {}
|
||||||
|
|
||||||
func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
|
func (c testPacketConn) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) {
|
||||||
return 0, 0, nil, nil
|
return 0, 0, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c testPacketConn) WriteTo(b []byte, dst net.Addr) (int, error) {
|
func (c testPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) {
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -655,7 +653,7 @@ type testPacketConnBadWrite struct {
|
|||||||
testPacketConn
|
testPacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c testPacketConnBadWrite) WriteTo(b []byte, dst net.Addr) (int, error) {
|
func (c testPacketConnBadWrite) WriteTo(_ []byte, _ net.Addr) (int, error) {
|
||||||
return 0, errors.New("bad write")
|
return 0, errors.New("bad write")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -684,7 +682,7 @@ type testPacketConnBadRead struct {
|
|||||||
testPacketConn
|
testPacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c testPacketConnBadRead) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
|
func (c testPacketConnBadRead) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) {
|
||||||
return 0, 0, nil, errors.New("bad read")
|
return 0, 0, nil, errors.New("bad read")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user