diff --git a/ping.go b/ping.go index ef277ab..2854e0a 100644 --- a/ping.go +++ b/ping.go @@ -60,6 +60,7 @@ import ( "math" "math/rand" "net" + "strconv" "sync" "sync/atomic" "syscall" @@ -87,25 +88,22 @@ var ( // New returns a new Pinger struct pointer. func New(addr string) *Pinger { r := rand.New(rand.NewSource(getSeed())) - firstUUID := uuid.New() - var firstSequence = map[uuid.UUID]map[int]struct{}{} - firstSequence[firstUUID] = make(map[int]struct{}) - return &Pinger{ - Count: -1, - Interval: time.Second, - RecordRtts: true, - Size: timeSliceLength + trackerLength, - Timeout: time.Duration(math.MaxInt64), + return &Pinger{ + Count: -1, + Interval: time.Second, + RecordRtts: true, + Size: timeSliceLength + trackerLength, + Timeout: time.Duration(math.MaxInt64), addr: addr, done: make(chan interface{}), id: r.Intn(math.MaxUint16), - trackerUUIDs: []uuid.UUID{firstUUID}, + currentUUID: uuid.New(), ipaddr: nil, ipv4: false, network: "ip", protocol: "udp", - awaitingSequences: firstSequence, + awaitingSequences: make(map[string]struct{}), TTL: 64, logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())}, } @@ -189,14 +187,14 @@ type Pinger struct { ipaddr *net.IPAddr addr string - // trackerUUIDs is the list of UUIDs being used for sending packets. - trackerUUIDs []uuid.UUID + // currentUUID is the current UUID used to build unique and recognizable packet payloads + currentUUID uuid.UUID ipv4 bool id int sequence int // awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts - awaitingSequences map[uuid.UUID]map[int]struct{} + awaitingSequences map[string]struct{} // network is one of "ip", "ip4", or "ip6". network string // protocol is "icmp" or "udp". @@ -607,27 +605,6 @@ func (p *Pinger) recvICMP( } } -// getPacketUUID scans the tracking slice for matches. -func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) { - var packetUUID uuid.UUID - err := packetUUID.UnmarshalBinary(pkt[timeSliceLength : timeSliceLength+trackerLength]) - if err != nil { - return nil, fmt.Errorf("error decoding tracking UUID: %w", err) - } - - for _, item := range p.trackerUUIDs { - if item == packetUUID { - return &packetUUID, nil - } - } - return nil, nil -} - -// getCurrentTrackerUUID grabs the latest tracker UUID. -func (p *Pinger) getCurrentTrackerUUID() uuid.UUID { - return p.trackerUUIDs[len(p.trackerUUIDs)-1] -} - func (p *Pinger) processPacket(recv *packet) error { receivedAt := time.Now() var proto int @@ -667,24 +644,29 @@ func (p *Pinger) processPacket(recv *packet) error { len(pkt.Data), pkt.Data) } - pktUUID, err := p.getPacketUUID(pkt.Data) - if err != nil || pktUUID == nil { - return err + var pktUUID uuid.UUID + err = pktUUID.UnmarshalBinary(pkt.Data[timeSliceLength : timeSliceLength+trackerLength]) + if err != nil { + return fmt.Errorf("error decoding tracking UUID: %w", err) } timestamp := bytesToTime(pkt.Data[:timeSliceLength]) inPkt.Rtt = receivedAt.Sub(timestamp) inPkt.Seq = pkt.Seq + + key := buildLookupKey(pktUUID, pkt.Seq) + // If we've already received this sequence, ignore it. - if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight { + if _, inflight := p.awaitingSequences[key]; !inflight { p.PacketsRecvDuplicates++ if p.OnDuplicateRecv != nil { p.OnDuplicateRecv(inPkt) } return nil } - // remove it from the list of sequences we're waiting for so we don't get duplicates. - delete(p.awaitingSequences[*pktUUID], pkt.Seq) + + // remove it from the list of sequences we're waiting for, so we don't get duplicates. + delete(p.awaitingSequences, key) p.updateStatistics(inPkt) default: // Very bad, not sure how this can happen @@ -705,8 +687,7 @@ func (p *Pinger) sendICMP(conn packetConn) error { dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone} } - currentUUID := p.getCurrentTrackerUUID() - uuidEncoded, err := currentUUID.MarshalBinary() + uuidEncoded, err := p.currentUUID.MarshalBinary() if err != nil { return fmt.Errorf("unable to marshal UUID binary: %w", err) } @@ -753,13 +734,11 @@ func (p *Pinger) sendICMP(conn packetConn) error { handler(outPkt) } // mark this sequence as in-flight - p.awaitingSequences[currentUUID][p.sequence] = struct{}{} + p.awaitingSequences[buildLookupKey(p.currentUUID, p.sequence)] = struct{}{} p.PacketsSent++ p.sequence++ if p.sequence > 65535 { - newUUID := uuid.New() - p.trackerUUIDs = append(p.trackerUUIDs, newUUID) - p.awaitingSequences[newUUID] = make(map[int]struct{}) + p.currentUUID = uuid.New() p.sequence = 0 } break @@ -818,3 +797,8 @@ var seed int64 = time.Now().UnixNano() func getSeed() int64 { return atomic.AddInt64(&seed, 1) } + +// buildLookupKey builds the key required for lookups on awaitingSequences map +func buildLookupKey(id uuid.UUID, sequenceId int) string { + return string(id[:]) + strconv.Itoa(sequenceId) +} diff --git a/ping_test.go b/ping_test.go index b8755e7..073aa6c 100644 --- a/ping_test.go +++ b/ping_test.go @@ -23,8 +23,7 @@ func TestProcessPacket(t *testing.T) { shouldBe1++ } - currentUUID := pinger.getCurrentTrackerUUID() - uuidEncoded, err := currentUUID.MarshalBinary() + uuidEncoded, err := pinger.currentUUID.MarshalBinary() if err != nil { t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -38,7 +37,7 @@ func TestProcessPacket(t *testing.T) { Seq: pinger.sequence, Data: data, } - pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{} + pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, pinger.sequence)] = struct{}{} msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply, @@ -67,7 +66,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) { shouldBe0++ } - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.currentUUID.MarshalBinary() if err != nil { t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -110,7 +109,7 @@ func TestProcessPacket_IDMismatch(t *testing.T) { shouldBe0++ } - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.currentUUID.MarshalBinary() if err != nil { t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -190,7 +189,7 @@ func TestProcessPacket_LargePacket(t *testing.T) { pinger := makeTestPinger() pinger.Size = 4096 - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.currentUUID.MarshalBinary() if err != nil { t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -543,7 +542,7 @@ func BenchmarkProcessPacket(b *testing.B) { pinger.protocol = "ip4:icmp" pinger.id = 123 - currentUUID, err := pinger.getCurrentTrackerUUID().MarshalBinary() + currentUUID, err := pinger.currentUUID.MarshalBinary() if err != nil { b.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -573,7 +572,7 @@ func BenchmarkProcessPacket(b *testing.B) { } for k := 0; k < b.N; k++ { - pinger.processPacket(&pkt) + _ = pinger.processPacket(&pkt) } } @@ -592,8 +591,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { dups++ } - currentUUID := pinger.getCurrentTrackerUUID() - uuidEncoded, err := currentUUID.MarshalBinary() + uuidEncoded, err := pinger.currentUUID.MarshalBinary() if err != nil { t.Fatal(fmt.Sprintf("unable to marshal UUID binary: %s", err)) } @@ -608,7 +606,7 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { Data: data, } // register the sequence as sent - pinger.awaitingSequences[currentUUID][0] = struct{}{} + pinger.awaitingSequences[buildLookupKey(pinger.currentUUID, 0)] = struct{}{} msg := &icmp.Message{ Type: ipv4.ICMPTypeEchoReply, @@ -640,14 +638,14 @@ type testPacketConn struct{} func (c testPacketConn) Close() error { return nil } func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho } func (c testPacketConn) SetFlagTTL() error { return nil } -func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil } -func (c testPacketConn) SetTTL(t int) {} +func (c testPacketConn) SetReadDeadline(_ time.Time) error { return nil } +func (c testPacketConn) SetTTL(_ int) {} -func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { +func (c testPacketConn) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) { return 0, 0, nil, nil } -func (c testPacketConn) WriteTo(b []byte, dst net.Addr) (int, error) { +func (c testPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) { return len(b), nil } @@ -655,7 +653,7 @@ type testPacketConnBadWrite struct { testPacketConn } -func (c testPacketConnBadWrite) WriteTo(b []byte, dst net.Addr) (int, error) { +func (c testPacketConnBadWrite) WriteTo(_ []byte, _ net.Addr) (int, error) { return 0, errors.New("bad write") } @@ -684,7 +682,7 @@ type testPacketConnBadRead struct { testPacketConn } -func (c testPacketConnBadRead) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { +func (c testPacketConnBadRead) ReadFrom(_ []byte) (n int, ttl int, src net.Addr, err error) { return 0, 0, nil, errors.New("bad read") }