mirror of
https://github.com/go-ping/ping.git
synced 2025-04-28 02:40:17 +00:00
Provide an abstraction over icmp.PacketConn (#166)
The differences between IPv4 and IPv6 APIs can be moved to a single type so that we don't need to keep track of them all over the code. We can also split Run() into two parts: the top one sets up the listener and the bottom one sends and receives packets. In this way, the bottom part can be tested using a mock packet connection.
This commit is contained in:
parent
e4e642a957
commit
ff8be33200
5
go.mod
5
go.mod
@ -2,4 +2,7 @@ module github.com/go-ping/ping
|
||||
|
||||
go 1.14
|
||||
|
||||
require golang.org/x/net v0.0.0-20200904194848-62affa334b73
|
||||
require (
|
||||
golang.org/x/net v0.0.0-20200904194848-62affa334b73
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
|
||||
)
|
||||
|
2
go.sum
2
go.sum
@ -3,6 +3,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20200904194848-62affa334b73 h1:MXfv8rhZWmFeqX3GNZRsd6vOLoaCHjYEX3qkRo3YBUA=
|
||||
golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=
|
||||
|
86
packetconn.go
Normal file
86
packetconn.go
Normal file
@ -0,0 +1,86 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
type packetConn interface {
|
||||
Close() error
|
||||
ICMPRequestType() icmp.Type
|
||||
ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error)
|
||||
SetFlagTTL() error
|
||||
SetReadDeadline(t time.Time) error
|
||||
WriteTo(b []byte, dst net.Addr) (int, error)
|
||||
}
|
||||
|
||||
type icmpConn struct {
|
||||
c *icmp.PacketConn
|
||||
}
|
||||
|
||||
func (c *icmpConn) Close() error {
|
||||
return c.c.Close()
|
||||
}
|
||||
|
||||
func (c *icmpConn) SetReadDeadline(t time.Time) error {
|
||||
return c.c.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *icmpConn) WriteTo(b []byte, dst net.Addr) (int, error) {
|
||||
return c.c.WriteTo(b, dst)
|
||||
}
|
||||
|
||||
type icmpv4Conn struct {
|
||||
icmpConn
|
||||
}
|
||||
|
||||
func (c *icmpv4Conn) SetFlagTTL() error {
|
||||
err := c.c.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL, true)
|
||||
if runtime.GOOS == "windows" {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
|
||||
var ttl int
|
||||
n, cm, src, err := c.c.IPv4PacketConn().ReadFrom(b)
|
||||
if cm != nil {
|
||||
ttl = cm.TTL
|
||||
}
|
||||
return n, ttl, src, err
|
||||
}
|
||||
|
||||
func (c icmpv4Conn) ICMPRequestType() icmp.Type {
|
||||
return ipv4.ICMPTypeEcho
|
||||
}
|
||||
|
||||
type icmpV6Conn struct {
|
||||
icmpConn
|
||||
}
|
||||
|
||||
func (c *icmpV6Conn) SetFlagTTL() error {
|
||||
err := c.c.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit, true)
|
||||
if runtime.GOOS == "windows" {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
|
||||
var ttl int
|
||||
n, cm, src, err := c.c.IPv6PacketConn().ReadFrom(b)
|
||||
if cm != nil {
|
||||
ttl = cm.HopLimit
|
||||
}
|
||||
return n, ttl, src, err
|
||||
}
|
||||
|
||||
func (c icmpV6Conn) ICMPRequestType() icmp.Type {
|
||||
return ipv6.ICMPTypeEchoRequest
|
||||
}
|
125
ping.go
125
ping.go
@ -61,7 +61,6 @@ import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
@ -70,6 +69,7 @@ import (
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -380,12 +380,7 @@ func (p *Pinger) SetLogger(logger Logger) {
|
||||
// done. If Count or Interval are not specified, it will run continuously until
|
||||
// it is interrupted.
|
||||
func (p *Pinger) Run() error {
|
||||
logger := p.logger
|
||||
if logger == nil {
|
||||
logger = NoopLogger{}
|
||||
}
|
||||
|
||||
var conn *icmp.PacketConn
|
||||
var conn packetConn
|
||||
var err error
|
||||
if p.ipaddr == nil {
|
||||
err = p.Resolve()
|
||||
@ -393,46 +388,60 @@ func (p *Pinger) Run() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if p.ipv4 {
|
||||
if conn, err = p.listen(ipv4Proto[p.protocol]); err != nil {
|
||||
if conn, err = p.listen(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = conn.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL, true); runtime.GOOS != "windows" && err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if conn, err = p.listen(ipv6Proto[p.protocol]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = conn.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit, true); runtime.GOOS != "windows" && err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return p.run(conn)
|
||||
}
|
||||
|
||||
func (p *Pinger) run(conn packetConn) error {
|
||||
if err := conn.SetFlagTTL(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer p.finish()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
recv := make(chan *packet, 5)
|
||||
defer close(recv)
|
||||
wg.Add(1)
|
||||
//nolint:errcheck
|
||||
go p.recvICMP(conn, recv, &wg)
|
||||
|
||||
if handler := p.OnSetup; handler != nil {
|
||||
handler()
|
||||
}
|
||||
|
||||
var g errgroup.Group
|
||||
|
||||
g.Go(func() error {
|
||||
defer p.Stop()
|
||||
return p.recvICMP(conn, recv)
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
defer p.Stop()
|
||||
return p.runLoop(conn, recv)
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (p *Pinger) runLoop(
|
||||
conn packetConn,
|
||||
recvCh <-chan *packet,
|
||||
) error {
|
||||
logger := p.logger
|
||||
if logger == nil {
|
||||
logger = NoopLogger{}
|
||||
}
|
||||
|
||||
timeout := time.NewTicker(p.Timeout)
|
||||
interval := time.NewTicker(p.Interval)
|
||||
defer func() {
|
||||
p.Stop()
|
||||
interval.Stop()
|
||||
timeout.Stop()
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
err = p.sendICMP(conn)
|
||||
if err != nil {
|
||||
if err := p.sendICMP(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -440,20 +449,23 @@ func (p *Pinger) Run() error {
|
||||
select {
|
||||
case <-p.done:
|
||||
return nil
|
||||
|
||||
case <-timeout.C:
|
||||
return nil
|
||||
case r := <-recv:
|
||||
|
||||
case r := <-recvCh:
|
||||
err := p.processPacket(r)
|
||||
if err != nil {
|
||||
// FIXME: this logs as FATAL but continues
|
||||
logger.Fatalf("processing received packet: %s", err)
|
||||
}
|
||||
|
||||
case <-interval.C:
|
||||
if p.Count > 0 && p.PacketsSent >= p.Count {
|
||||
interval.Stop()
|
||||
continue
|
||||
}
|
||||
err = p.sendICMP(conn)
|
||||
err := p.sendICMP(conn)
|
||||
if err != nil {
|
||||
// FIXME: this logs as FATAL but continues
|
||||
logger.Fatalf("sending packet: %s", err)
|
||||
@ -531,12 +543,9 @@ func newExpBackoff(baseDelay time.Duration, maxExp int64) expBackoff {
|
||||
}
|
||||
|
||||
func (p *Pinger) recvICMP(
|
||||
conn *icmp.PacketConn,
|
||||
conn packetConn,
|
||||
recv chan<- *packet,
|
||||
wg *sync.WaitGroup,
|
||||
) error {
|
||||
defer wg.Done()
|
||||
|
||||
// Start by waiting for 50 µs and increase to a possible maximum of ~ 100 ms.
|
||||
expBackoff := newExpBackoff(50*time.Microsecond, 11)
|
||||
delay := expBackoff.Get()
|
||||
@ -552,31 +561,17 @@ func (p *Pinger) recvICMP(
|
||||
}
|
||||
var n, ttl int
|
||||
var err error
|
||||
if p.ipv4 {
|
||||
var cm *ipv4.ControlMessage
|
||||
n, cm, _, err = conn.IPv4PacketConn().ReadFrom(bytes)
|
||||
if cm != nil {
|
||||
ttl = cm.TTL
|
||||
}
|
||||
} else {
|
||||
var cm *ipv6.ControlMessage
|
||||
n, cm, _, err = conn.IPv6PacketConn().ReadFrom(bytes)
|
||||
if cm != nil {
|
||||
ttl = cm.HopLimit
|
||||
}
|
||||
}
|
||||
n, ttl, _, err = conn.ReadFrom(bytes)
|
||||
if err != nil {
|
||||
if neterr, ok := err.(*net.OpError); ok {
|
||||
if neterr.Timeout() {
|
||||
// Read timeout
|
||||
delay = expBackoff.Get()
|
||||
continue
|
||||
} else {
|
||||
p.Stop()
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-p.done:
|
||||
@ -658,14 +653,7 @@ func (p *Pinger) processPacket(recv *packet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
|
||||
var typ icmp.Type
|
||||
if p.ipv4 {
|
||||
typ = ipv4.ICMPTypeEcho
|
||||
} else {
|
||||
typ = ipv6.ICMPTypeEchoRequest
|
||||
}
|
||||
|
||||
func (p *Pinger) sendICMP(conn packetConn) error {
|
||||
var dst net.Addr = p.ipaddr
|
||||
if p.protocol == "udp" {
|
||||
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
|
||||
@ -683,7 +671,7 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
|
||||
}
|
||||
|
||||
msg := &icmp.Message{
|
||||
Type: typ,
|
||||
Type: conn.ICMPRequestType(),
|
||||
Code: 0,
|
||||
Body: body,
|
||||
}
|
||||
@ -700,6 +688,7 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
handler := p.OnSend
|
||||
if handler != nil {
|
||||
@ -721,8 +710,22 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pinger) listen(netProto string) (*icmp.PacketConn, error) {
|
||||
conn, err := icmp.ListenPacket(netProto, p.Source)
|
||||
func (p *Pinger) listen() (packetConn, error) {
|
||||
var (
|
||||
conn packetConn
|
||||
err error
|
||||
)
|
||||
|
||||
if p.ipv4 {
|
||||
var c icmpv4Conn
|
||||
c.c, err = icmp.ListenPacket(ipv4Proto[p.protocol], p.Source)
|
||||
conn = &c
|
||||
} else {
|
||||
var c icmpV6Conn
|
||||
c.c, err = icmp.ListenPacket(ipv6Proto[p.protocol], p.Source)
|
||||
conn = &c
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
p.Stop()
|
||||
return nil, err
|
||||
|
136
ping_test.go
136
ping_test.go
@ -2,8 +2,10 @@ package ping
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -466,6 +468,7 @@ func makeTestPinger() *Pinger {
|
||||
}
|
||||
|
||||
func AssertNoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Errorf("Expected No Error but got %s, Stack:\n%s",
|
||||
err, string(debug.Stack()))
|
||||
@ -473,6 +476,7 @@ func AssertNoError(t *testing.T, err error) {
|
||||
}
|
||||
|
||||
func AssertError(t *testing.T, err error, info string) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
t.Errorf("Expected Error but got %s, %s, Stack:\n%s",
|
||||
err, info, string(debug.Stack()))
|
||||
@ -480,6 +484,7 @@ func AssertError(t *testing.T, err error, info string) {
|
||||
}
|
||||
|
||||
func AssertEqualStrings(t *testing.T, expected, actual string) {
|
||||
t.Helper()
|
||||
if expected != actual {
|
||||
t.Errorf("Expected %s, got %s, Stack:\n%s",
|
||||
expected, actual, string(debug.Stack()))
|
||||
@ -487,6 +492,7 @@ func AssertEqualStrings(t *testing.T, expected, actual string) {
|
||||
}
|
||||
|
||||
func AssertNotEqualStrings(t *testing.T, expected, actual string) {
|
||||
t.Helper()
|
||||
if expected == actual {
|
||||
t.Errorf("Expected %s, got %s, Stack:\n%s",
|
||||
expected, actual, string(debug.Stack()))
|
||||
@ -494,12 +500,14 @@ func AssertNotEqualStrings(t *testing.T, expected, actual string) {
|
||||
}
|
||||
|
||||
func AssertTrue(t *testing.T, b bool) {
|
||||
t.Helper()
|
||||
if !b {
|
||||
t.Errorf("Expected True, got False, Stack:\n%s", string(debug.Stack()))
|
||||
}
|
||||
}
|
||||
|
||||
func AssertFalse(t *testing.T, b bool) {
|
||||
t.Helper()
|
||||
if b {
|
||||
t.Errorf("Expected False, got True, Stack:\n%s", string(debug.Stack()))
|
||||
}
|
||||
@ -596,3 +604,131 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
|
||||
AssertTrue(t, dups == 1)
|
||||
AssertTrue(t, pinger.PacketsRecvDuplicates == 1)
|
||||
}
|
||||
|
||||
type testPacketConn struct{}
|
||||
|
||||
func (c testPacketConn) Close() error { return nil }
|
||||
func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho }
|
||||
func (c testPacketConn) SetFlagTTL() error { return nil }
|
||||
func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
|
||||
return 0, 0, nil, nil
|
||||
}
|
||||
|
||||
func (c testPacketConn) WriteTo(b []byte, dst net.Addr) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
type testPacketConnBadWrite struct {
|
||||
testPacketConn
|
||||
}
|
||||
|
||||
func (c testPacketConnBadWrite) WriteTo(b []byte, dst net.Addr) (int, error) {
|
||||
return 0, errors.New("bad write")
|
||||
}
|
||||
|
||||
func TestRunBadWrite(t *testing.T) {
|
||||
pinger := New("127.0.0.1")
|
||||
pinger.Count = 1
|
||||
|
||||
err := pinger.Resolve()
|
||||
AssertNoError(t, err)
|
||||
|
||||
var conn testPacketConnBadWrite
|
||||
|
||||
err = pinger.run(conn)
|
||||
AssertTrue(t, err != nil)
|
||||
|
||||
stats := pinger.Statistics()
|
||||
AssertTrue(t, stats != nil)
|
||||
if stats == nil {
|
||||
t.FailNow()
|
||||
}
|
||||
AssertTrue(t, stats.PacketsSent == 0)
|
||||
AssertTrue(t, stats.PacketsRecv == 0)
|
||||
}
|
||||
|
||||
type testPacketConnBadRead struct {
|
||||
testPacketConn
|
||||
}
|
||||
|
||||
func (c testPacketConnBadRead) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
|
||||
return 0, 0, nil, errors.New("bad read")
|
||||
}
|
||||
|
||||
func TestRunBadRead(t *testing.T) {
|
||||
pinger := New("127.0.0.1")
|
||||
pinger.Count = 1
|
||||
|
||||
err := pinger.Resolve()
|
||||
AssertNoError(t, err)
|
||||
|
||||
var conn testPacketConnBadRead
|
||||
|
||||
err = pinger.run(conn)
|
||||
AssertTrue(t, err != nil)
|
||||
|
||||
stats := pinger.Statistics()
|
||||
AssertTrue(t, stats != nil)
|
||||
if stats == nil {
|
||||
t.FailNow()
|
||||
}
|
||||
AssertTrue(t, stats.PacketsSent == 1)
|
||||
AssertTrue(t, stats.PacketsRecv == 0)
|
||||
}
|
||||
|
||||
type testPacketConnOK struct {
|
||||
testPacketConn
|
||||
writeDone int32
|
||||
buf []byte
|
||||
dst net.Addr
|
||||
}
|
||||
|
||||
func (c *testPacketConnOK) WriteTo(b []byte, dst net.Addr) (int, error) {
|
||||
c.buf = make([]byte, len(b))
|
||||
c.dst = dst
|
||||
n := copy(c.buf, b)
|
||||
atomic.StoreInt32(&c.writeDone, 1)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *testPacketConnOK) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
|
||||
if atomic.LoadInt32(&c.writeDone) == 0 {
|
||||
return 0, 0, nil, nil
|
||||
}
|
||||
msg, err := icmp.ParseMessage(ipv4.ICMPTypeEcho.Protocol(), c.buf)
|
||||
if err != nil {
|
||||
return 0, 0, nil, err
|
||||
}
|
||||
msg.Type = ipv4.ICMPTypeEchoReply
|
||||
buf, err := msg.Marshal(nil)
|
||||
if err != nil {
|
||||
return 0, 0, nil, err
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
return copy(b, buf), 64, c.dst, nil
|
||||
}
|
||||
|
||||
func TestRunOK(t *testing.T) {
|
||||
pinger := New("127.0.0.1")
|
||||
pinger.Count = 1
|
||||
|
||||
err := pinger.Resolve()
|
||||
AssertNoError(t, err)
|
||||
|
||||
conn := new(testPacketConnOK)
|
||||
|
||||
err = pinger.run(conn)
|
||||
AssertTrue(t, err == nil)
|
||||
|
||||
stats := pinger.Statistics()
|
||||
AssertTrue(t, stats != nil)
|
||||
if stats == nil {
|
||||
t.FailNow()
|
||||
}
|
||||
AssertTrue(t, stats.PacketsSent == 1)
|
||||
AssertTrue(t, stats.PacketsRecv == 1)
|
||||
AssertTrue(t, stats.MinRtt >= 10*time.Millisecond)
|
||||
AssertTrue(t, stats.MinRtt <= 12*time.Millisecond)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user