Provide an abstraction over icmp.PacketConn (#166)

The differences between IPv4 and IPv6 APIs can be moved to a single type so that we don't need to keep track of them all over the code. We can also split Run() into two parts: the top one sets up the listener and the bottom one sends and receives packets. In this way, the bottom part can be tested using a mock packet connection.
This commit is contained in:
Marcelo Magallon 2021-05-06 17:38:00 -06:00 committed by GitHub
parent e4e642a957
commit ff8be33200
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 292 additions and 62 deletions

5
go.mod
View File

@ -2,4 +2,7 @@ module github.com/go-ping/ping
go 1.14
require golang.org/x/net v0.0.0-20200904194848-62affa334b73
require (
golang.org/x/net v0.0.0-20200904194848-62affa334b73
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
)

2
go.sum
View File

@ -3,6 +3,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200904194848-62affa334b73 h1:MXfv8rhZWmFeqX3GNZRsd6vOLoaCHjYEX3qkRo3YBUA=
golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=

86
packetconn.go Normal file
View File

@ -0,0 +1,86 @@
package ping
import (
"net"
"runtime"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
type packetConn interface {
Close() error
ICMPRequestType() icmp.Type
ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error)
SetFlagTTL() error
SetReadDeadline(t time.Time) error
WriteTo(b []byte, dst net.Addr) (int, error)
}
type icmpConn struct {
c *icmp.PacketConn
}
func (c *icmpConn) Close() error {
return c.c.Close()
}
func (c *icmpConn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t)
}
func (c *icmpConn) WriteTo(b []byte, dst net.Addr) (int, error) {
return c.c.WriteTo(b, dst)
}
type icmpv4Conn struct {
icmpConn
}
func (c *icmpv4Conn) SetFlagTTL() error {
err := c.c.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL, true)
if runtime.GOOS == "windows" {
return nil
}
return err
}
func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
var ttl int
n, cm, src, err := c.c.IPv4PacketConn().ReadFrom(b)
if cm != nil {
ttl = cm.TTL
}
return n, ttl, src, err
}
func (c icmpv4Conn) ICMPRequestType() icmp.Type {
return ipv4.ICMPTypeEcho
}
type icmpV6Conn struct {
icmpConn
}
func (c *icmpV6Conn) SetFlagTTL() error {
err := c.c.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit, true)
if runtime.GOOS == "windows" {
return nil
}
return err
}
func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
var ttl int
n, cm, src, err := c.c.IPv6PacketConn().ReadFrom(b)
if cm != nil {
ttl = cm.HopLimit
}
return n, ttl, src, err
}
func (c icmpV6Conn) ICMPRequestType() icmp.Type {
return ipv6.ICMPTypeEchoRequest
}

125
ping.go
View File

@ -61,7 +61,6 @@ import (
"math"
"math/rand"
"net"
"runtime"
"sync"
"sync/atomic"
"syscall"
@ -70,6 +69,7 @@ import (
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sync/errgroup"
)
const (
@ -380,12 +380,7 @@ func (p *Pinger) SetLogger(logger Logger) {
// done. If Count or Interval are not specified, it will run continuously until
// it is interrupted.
func (p *Pinger) Run() error {
logger := p.logger
if logger == nil {
logger = NoopLogger{}
}
var conn *icmp.PacketConn
var conn packetConn
var err error
if p.ipaddr == nil {
err = p.Resolve()
@ -393,46 +388,60 @@ func (p *Pinger) Run() error {
if err != nil {
return err
}
if p.ipv4 {
if conn, err = p.listen(ipv4Proto[p.protocol]); err != nil {
return err
}
if err = conn.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL, true); runtime.GOOS != "windows" && err != nil {
return err
}
} else {
if conn, err = p.listen(ipv6Proto[p.protocol]); err != nil {
return err
}
if err = conn.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit, true); runtime.GOOS != "windows" && err != nil {
return err
}
if conn, err = p.listen(); err != nil {
return err
}
defer conn.Close()
return p.run(conn)
}
func (p *Pinger) run(conn packetConn) error {
if err := conn.SetFlagTTL(); err != nil {
return err
}
defer p.finish()
var wg sync.WaitGroup
recv := make(chan *packet, 5)
defer close(recv)
wg.Add(1)
//nolint:errcheck
go p.recvICMP(conn, recv, &wg)
if handler := p.OnSetup; handler != nil {
handler()
}
var g errgroup.Group
g.Go(func() error {
defer p.Stop()
return p.recvICMP(conn, recv)
})
g.Go(func() error {
defer p.Stop()
return p.runLoop(conn, recv)
})
return g.Wait()
}
func (p *Pinger) runLoop(
conn packetConn,
recvCh <-chan *packet,
) error {
logger := p.logger
if logger == nil {
logger = NoopLogger{}
}
timeout := time.NewTicker(p.Timeout)
interval := time.NewTicker(p.Interval)
defer func() {
p.Stop()
interval.Stop()
timeout.Stop()
wg.Wait()
}()
err = p.sendICMP(conn)
if err != nil {
if err := p.sendICMP(conn); err != nil {
return err
}
@ -440,20 +449,23 @@ func (p *Pinger) Run() error {
select {
case <-p.done:
return nil
case <-timeout.C:
return nil
case r := <-recv:
case r := <-recvCh:
err := p.processPacket(r)
if err != nil {
// FIXME: this logs as FATAL but continues
logger.Fatalf("processing received packet: %s", err)
}
case <-interval.C:
if p.Count > 0 && p.PacketsSent >= p.Count {
interval.Stop()
continue
}
err = p.sendICMP(conn)
err := p.sendICMP(conn)
if err != nil {
// FIXME: this logs as FATAL but continues
logger.Fatalf("sending packet: %s", err)
@ -531,12 +543,9 @@ func newExpBackoff(baseDelay time.Duration, maxExp int64) expBackoff {
}
func (p *Pinger) recvICMP(
conn *icmp.PacketConn,
conn packetConn,
recv chan<- *packet,
wg *sync.WaitGroup,
) error {
defer wg.Done()
// Start by waiting for 50 µs and increase to a possible maximum of ~ 100 ms.
expBackoff := newExpBackoff(50*time.Microsecond, 11)
delay := expBackoff.Get()
@ -552,30 +561,16 @@ func (p *Pinger) recvICMP(
}
var n, ttl int
var err error
if p.ipv4 {
var cm *ipv4.ControlMessage
n, cm, _, err = conn.IPv4PacketConn().ReadFrom(bytes)
if cm != nil {
ttl = cm.TTL
}
} else {
var cm *ipv6.ControlMessage
n, cm, _, err = conn.IPv6PacketConn().ReadFrom(bytes)
if cm != nil {
ttl = cm.HopLimit
}
}
n, ttl, _, err = conn.ReadFrom(bytes)
if err != nil {
if neterr, ok := err.(*net.OpError); ok {
if neterr.Timeout() {
// Read timeout
delay = expBackoff.Get()
continue
} else {
p.Stop()
return err
}
}
return err
}
select {
@ -658,14 +653,7 @@ func (p *Pinger) processPacket(recv *packet) error {
return nil
}
func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
var typ icmp.Type
if p.ipv4 {
typ = ipv4.ICMPTypeEcho
} else {
typ = ipv6.ICMPTypeEchoRequest
}
func (p *Pinger) sendICMP(conn packetConn) error {
var dst net.Addr = p.ipaddr
if p.protocol == "udp" {
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
@ -683,7 +671,7 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
}
msg := &icmp.Message{
Type: typ,
Type: conn.ICMPRequestType(),
Code: 0,
Body: body,
}
@ -700,6 +688,7 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
continue
}
}
return err
}
handler := p.OnSend
if handler != nil {
@ -721,8 +710,22 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
return nil
}
func (p *Pinger) listen(netProto string) (*icmp.PacketConn, error) {
conn, err := icmp.ListenPacket(netProto, p.Source)
func (p *Pinger) listen() (packetConn, error) {
var (
conn packetConn
err error
)
if p.ipv4 {
var c icmpv4Conn
c.c, err = icmp.ListenPacket(ipv4Proto[p.protocol], p.Source)
conn = &c
} else {
var c icmpV6Conn
c.c, err = icmp.ListenPacket(ipv6Proto[p.protocol], p.Source)
conn = &c
}
if err != nil {
p.Stop()
return nil, err

View File

@ -2,8 +2,10 @@ package ping
import (
"bytes"
"errors"
"net"
"runtime/debug"
"sync/atomic"
"testing"
"time"
@ -466,6 +468,7 @@ func makeTestPinger() *Pinger {
}
func AssertNoError(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Errorf("Expected No Error but got %s, Stack:\n%s",
err, string(debug.Stack()))
@ -473,6 +476,7 @@ func AssertNoError(t *testing.T, err error) {
}
func AssertError(t *testing.T, err error, info string) {
t.Helper()
if err == nil {
t.Errorf("Expected Error but got %s, %s, Stack:\n%s",
err, info, string(debug.Stack()))
@ -480,6 +484,7 @@ func AssertError(t *testing.T, err error, info string) {
}
func AssertEqualStrings(t *testing.T, expected, actual string) {
t.Helper()
if expected != actual {
t.Errorf("Expected %s, got %s, Stack:\n%s",
expected, actual, string(debug.Stack()))
@ -487,6 +492,7 @@ func AssertEqualStrings(t *testing.T, expected, actual string) {
}
func AssertNotEqualStrings(t *testing.T, expected, actual string) {
t.Helper()
if expected == actual {
t.Errorf("Expected %s, got %s, Stack:\n%s",
expected, actual, string(debug.Stack()))
@ -494,12 +500,14 @@ func AssertNotEqualStrings(t *testing.T, expected, actual string) {
}
func AssertTrue(t *testing.T, b bool) {
t.Helper()
if !b {
t.Errorf("Expected True, got False, Stack:\n%s", string(debug.Stack()))
}
}
func AssertFalse(t *testing.T, b bool) {
t.Helper()
if b {
t.Errorf("Expected False, got True, Stack:\n%s", string(debug.Stack()))
}
@ -596,3 +604,131 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
AssertTrue(t, dups == 1)
AssertTrue(t, pinger.PacketsRecvDuplicates == 1)
}
type testPacketConn struct{}
func (c testPacketConn) Close() error { return nil }
func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho }
func (c testPacketConn) SetFlagTTL() error { return nil }
func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil }
func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, nil, nil
}
func (c testPacketConn) WriteTo(b []byte, dst net.Addr) (int, error) {
return len(b), nil
}
type testPacketConnBadWrite struct {
testPacketConn
}
func (c testPacketConnBadWrite) WriteTo(b []byte, dst net.Addr) (int, error) {
return 0, errors.New("bad write")
}
func TestRunBadWrite(t *testing.T) {
pinger := New("127.0.0.1")
pinger.Count = 1
err := pinger.Resolve()
AssertNoError(t, err)
var conn testPacketConnBadWrite
err = pinger.run(conn)
AssertTrue(t, err != nil)
stats := pinger.Statistics()
AssertTrue(t, stats != nil)
if stats == nil {
t.FailNow()
}
AssertTrue(t, stats.PacketsSent == 0)
AssertTrue(t, stats.PacketsRecv == 0)
}
type testPacketConnBadRead struct {
testPacketConn
}
func (c testPacketConnBadRead) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, nil, errors.New("bad read")
}
func TestRunBadRead(t *testing.T) {
pinger := New("127.0.0.1")
pinger.Count = 1
err := pinger.Resolve()
AssertNoError(t, err)
var conn testPacketConnBadRead
err = pinger.run(conn)
AssertTrue(t, err != nil)
stats := pinger.Statistics()
AssertTrue(t, stats != nil)
if stats == nil {
t.FailNow()
}
AssertTrue(t, stats.PacketsSent == 1)
AssertTrue(t, stats.PacketsRecv == 0)
}
type testPacketConnOK struct {
testPacketConn
writeDone int32
buf []byte
dst net.Addr
}
func (c *testPacketConnOK) WriteTo(b []byte, dst net.Addr) (int, error) {
c.buf = make([]byte, len(b))
c.dst = dst
n := copy(c.buf, b)
atomic.StoreInt32(&c.writeDone, 1)
return n, nil
}
func (c *testPacketConnOK) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
if atomic.LoadInt32(&c.writeDone) == 0 {
return 0, 0, nil, nil
}
msg, err := icmp.ParseMessage(ipv4.ICMPTypeEcho.Protocol(), c.buf)
if err != nil {
return 0, 0, nil, err
}
msg.Type = ipv4.ICMPTypeEchoReply
buf, err := msg.Marshal(nil)
if err != nil {
return 0, 0, nil, err
}
time.Sleep(10 * time.Millisecond)
return copy(b, buf), 64, c.dst, nil
}
func TestRunOK(t *testing.T) {
pinger := New("127.0.0.1")
pinger.Count = 1
err := pinger.Resolve()
AssertNoError(t, err)
conn := new(testPacketConnOK)
err = pinger.run(conn)
AssertTrue(t, err == nil)
stats := pinger.Statistics()
AssertTrue(t, stats != nil)
if stats == nil {
t.FailNow()
}
AssertTrue(t, stats.PacketsSent == 1)
AssertTrue(t, stats.PacketsRecv == 1)
AssertTrue(t, stats.MinRtt >= 10*time.Millisecond)
AssertTrue(t, stats.MinRtt <= 12*time.Millisecond)
}