diff --git a/go.mod b/go.mod index 44ad2f7..9a6710d 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/go-ping/ping go 1.14 -require golang.org/x/net v0.0.0-20200904194848-62affa334b73 +require ( + golang.org/x/net v0.0.0-20200904194848-62affa334b73 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect +) diff --git a/go.sum b/go.sum index 6bc5e04..ab1260f 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20200904194848-62affa334b73 h1:MXfv8rhZWmFeqX3GNZRsd6vOLoaCHjYEX3qkRo3YBUA= golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= diff --git a/packetconn.go b/packetconn.go new file mode 100644 index 0000000..a590bfb --- /dev/null +++ b/packetconn.go @@ -0,0 +1,86 @@ +package ping + +import ( + "net" + "runtime" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +type packetConn interface { + Close() error + ICMPRequestType() icmp.Type + ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) + SetFlagTTL() error + SetReadDeadline(t time.Time) error + WriteTo(b []byte, dst net.Addr) (int, error) +} + +type icmpConn struct { + c *icmp.PacketConn +} + +func (c *icmpConn) Close() error { + return c.c.Close() +} + +func (c *icmpConn) SetReadDeadline(t time.Time) error { + return c.c.SetReadDeadline(t) +} + +func (c *icmpConn) WriteTo(b []byte, dst net.Addr) (int, error) { + return c.c.WriteTo(b, dst) +} + +type icmpv4Conn struct { + icmpConn +} + +func (c *icmpv4Conn) SetFlagTTL() error { + err := c.c.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL, true) + if runtime.GOOS == "windows" { + return nil + } + return err +} + +func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) { + var ttl int + n, cm, src, err := c.c.IPv4PacketConn().ReadFrom(b) + if cm != nil { + ttl = cm.TTL + } + return n, ttl, src, err +} + +func (c icmpv4Conn) ICMPRequestType() icmp.Type { + return ipv4.ICMPTypeEcho +} + +type icmpV6Conn struct { + icmpConn +} + +func (c *icmpV6Conn) SetFlagTTL() error { + err := c.c.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit, true) + if runtime.GOOS == "windows" { + return nil + } + return err +} + +func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) { + var ttl int + n, cm, src, err := c.c.IPv6PacketConn().ReadFrom(b) + if cm != nil { + ttl = cm.HopLimit + } + return n, ttl, src, err +} + +func (c icmpV6Conn) ICMPRequestType() icmp.Type { + return ipv6.ICMPTypeEchoRequest +} diff --git a/ping.go b/ping.go index 520efb8..cbaede8 100644 --- a/ping.go +++ b/ping.go @@ -61,7 +61,6 @@ import ( "math" "math/rand" "net" - "runtime" "sync" "sync/atomic" "syscall" @@ -70,6 +69,7 @@ import ( "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.org/x/sync/errgroup" ) const ( @@ -380,12 +380,7 @@ func (p *Pinger) SetLogger(logger Logger) { // done. If Count or Interval are not specified, it will run continuously until // it is interrupted. func (p *Pinger) Run() error { - logger := p.logger - if logger == nil { - logger = NoopLogger{} - } - - var conn *icmp.PacketConn + var conn packetConn var err error if p.ipaddr == nil { err = p.Resolve() @@ -393,46 +388,60 @@ func (p *Pinger) Run() error { if err != nil { return err } - if p.ipv4 { - if conn, err = p.listen(ipv4Proto[p.protocol]); err != nil { - return err - } - if err = conn.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL, true); runtime.GOOS != "windows" && err != nil { - return err - } - } else { - if conn, err = p.listen(ipv6Proto[p.protocol]); err != nil { - return err - } - if err = conn.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit, true); runtime.GOOS != "windows" && err != nil { - return err - } + if conn, err = p.listen(); err != nil { + return err } defer conn.Close() + + return p.run(conn) +} + +func (p *Pinger) run(conn packetConn) error { + if err := conn.SetFlagTTL(); err != nil { + return err + } defer p.finish() - var wg sync.WaitGroup recv := make(chan *packet, 5) defer close(recv) - wg.Add(1) - //nolint:errcheck - go p.recvICMP(conn, recv, &wg) if handler := p.OnSetup; handler != nil { handler() } + var g errgroup.Group + + g.Go(func() error { + defer p.Stop() + return p.recvICMP(conn, recv) + }) + + g.Go(func() error { + defer p.Stop() + return p.runLoop(conn, recv) + }) + + return g.Wait() +} + +func (p *Pinger) runLoop( + conn packetConn, + recvCh <-chan *packet, +) error { + logger := p.logger + if logger == nil { + logger = NoopLogger{} + } + timeout := time.NewTicker(p.Timeout) interval := time.NewTicker(p.Interval) defer func() { p.Stop() interval.Stop() timeout.Stop() - wg.Wait() }() - err = p.sendICMP(conn) - if err != nil { + if err := p.sendICMP(conn); err != nil { return err } @@ -440,20 +449,23 @@ func (p *Pinger) Run() error { select { case <-p.done: return nil + case <-timeout.C: return nil - case r := <-recv: + + case r := <-recvCh: err := p.processPacket(r) if err != nil { // FIXME: this logs as FATAL but continues logger.Fatalf("processing received packet: %s", err) } + case <-interval.C: if p.Count > 0 && p.PacketsSent >= p.Count { interval.Stop() continue } - err = p.sendICMP(conn) + err := p.sendICMP(conn) if err != nil { // FIXME: this logs as FATAL but continues logger.Fatalf("sending packet: %s", err) @@ -531,12 +543,9 @@ func newExpBackoff(baseDelay time.Duration, maxExp int64) expBackoff { } func (p *Pinger) recvICMP( - conn *icmp.PacketConn, + conn packetConn, recv chan<- *packet, - wg *sync.WaitGroup, ) error { - defer wg.Done() - // Start by waiting for 50 µs and increase to a possible maximum of ~ 100 ms. expBackoff := newExpBackoff(50*time.Microsecond, 11) delay := expBackoff.Get() @@ -552,30 +561,16 @@ func (p *Pinger) recvICMP( } var n, ttl int var err error - if p.ipv4 { - var cm *ipv4.ControlMessage - n, cm, _, err = conn.IPv4PacketConn().ReadFrom(bytes) - if cm != nil { - ttl = cm.TTL - } - } else { - var cm *ipv6.ControlMessage - n, cm, _, err = conn.IPv6PacketConn().ReadFrom(bytes) - if cm != nil { - ttl = cm.HopLimit - } - } + n, ttl, _, err = conn.ReadFrom(bytes) if err != nil { if neterr, ok := err.(*net.OpError); ok { if neterr.Timeout() { // Read timeout delay = expBackoff.Get() continue - } else { - p.Stop() - return err } } + return err } select { @@ -658,14 +653,7 @@ func (p *Pinger) processPacket(recv *packet) error { return nil } -func (p *Pinger) sendICMP(conn *icmp.PacketConn) error { - var typ icmp.Type - if p.ipv4 { - typ = ipv4.ICMPTypeEcho - } else { - typ = ipv6.ICMPTypeEchoRequest - } - +func (p *Pinger) sendICMP(conn packetConn) error { var dst net.Addr = p.ipaddr if p.protocol == "udp" { dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone} @@ -683,7 +671,7 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error { } msg := &icmp.Message{ - Type: typ, + Type: conn.ICMPRequestType(), Code: 0, Body: body, } @@ -700,6 +688,7 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error { continue } } + return err } handler := p.OnSend if handler != nil { @@ -721,8 +710,22 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error { return nil } -func (p *Pinger) listen(netProto string) (*icmp.PacketConn, error) { - conn, err := icmp.ListenPacket(netProto, p.Source) +func (p *Pinger) listen() (packetConn, error) { + var ( + conn packetConn + err error + ) + + if p.ipv4 { + var c icmpv4Conn + c.c, err = icmp.ListenPacket(ipv4Proto[p.protocol], p.Source) + conn = &c + } else { + var c icmpV6Conn + c.c, err = icmp.ListenPacket(ipv6Proto[p.protocol], p.Source) + conn = &c + } + if err != nil { p.Stop() return nil, err diff --git a/ping_test.go b/ping_test.go index 14ff23b..cfda052 100644 --- a/ping_test.go +++ b/ping_test.go @@ -2,8 +2,10 @@ package ping import ( "bytes" + "errors" "net" "runtime/debug" + "sync/atomic" "testing" "time" @@ -466,6 +468,7 @@ func makeTestPinger() *Pinger { } func AssertNoError(t *testing.T, err error) { + t.Helper() if err != nil { t.Errorf("Expected No Error but got %s, Stack:\n%s", err, string(debug.Stack())) @@ -473,6 +476,7 @@ func AssertNoError(t *testing.T, err error) { } func AssertError(t *testing.T, err error, info string) { + t.Helper() if err == nil { t.Errorf("Expected Error but got %s, %s, Stack:\n%s", err, info, string(debug.Stack())) @@ -480,6 +484,7 @@ func AssertError(t *testing.T, err error, info string) { } func AssertEqualStrings(t *testing.T, expected, actual string) { + t.Helper() if expected != actual { t.Errorf("Expected %s, got %s, Stack:\n%s", expected, actual, string(debug.Stack())) @@ -487,6 +492,7 @@ func AssertEqualStrings(t *testing.T, expected, actual string) { } func AssertNotEqualStrings(t *testing.T, expected, actual string) { + t.Helper() if expected == actual { t.Errorf("Expected %s, got %s, Stack:\n%s", expected, actual, string(debug.Stack())) @@ -494,12 +500,14 @@ func AssertNotEqualStrings(t *testing.T, expected, actual string) { } func AssertTrue(t *testing.T, b bool) { + t.Helper() if !b { t.Errorf("Expected True, got False, Stack:\n%s", string(debug.Stack())) } } func AssertFalse(t *testing.T, b bool) { + t.Helper() if b { t.Errorf("Expected False, got True, Stack:\n%s", string(debug.Stack())) } @@ -596,3 +604,131 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) { AssertTrue(t, dups == 1) AssertTrue(t, pinger.PacketsRecvDuplicates == 1) } + +type testPacketConn struct{} + +func (c testPacketConn) Close() error { return nil } +func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTypeEcho } +func (c testPacketConn) SetFlagTTL() error { return nil } +func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil } + +func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { + return 0, 0, nil, nil +} + +func (c testPacketConn) WriteTo(b []byte, dst net.Addr) (int, error) { + return len(b), nil +} + +type testPacketConnBadWrite struct { + testPacketConn +} + +func (c testPacketConnBadWrite) WriteTo(b []byte, dst net.Addr) (int, error) { + return 0, errors.New("bad write") +} + +func TestRunBadWrite(t *testing.T) { + pinger := New("127.0.0.1") + pinger.Count = 1 + + err := pinger.Resolve() + AssertNoError(t, err) + + var conn testPacketConnBadWrite + + err = pinger.run(conn) + AssertTrue(t, err != nil) + + stats := pinger.Statistics() + AssertTrue(t, stats != nil) + if stats == nil { + t.FailNow() + } + AssertTrue(t, stats.PacketsSent == 0) + AssertTrue(t, stats.PacketsRecv == 0) +} + +type testPacketConnBadRead struct { + testPacketConn +} + +func (c testPacketConnBadRead) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { + return 0, 0, nil, errors.New("bad read") +} + +func TestRunBadRead(t *testing.T) { + pinger := New("127.0.0.1") + pinger.Count = 1 + + err := pinger.Resolve() + AssertNoError(t, err) + + var conn testPacketConnBadRead + + err = pinger.run(conn) + AssertTrue(t, err != nil) + + stats := pinger.Statistics() + AssertTrue(t, stats != nil) + if stats == nil { + t.FailNow() + } + AssertTrue(t, stats.PacketsSent == 1) + AssertTrue(t, stats.PacketsRecv == 0) +} + +type testPacketConnOK struct { + testPacketConn + writeDone int32 + buf []byte + dst net.Addr +} + +func (c *testPacketConnOK) WriteTo(b []byte, dst net.Addr) (int, error) { + c.buf = make([]byte, len(b)) + c.dst = dst + n := copy(c.buf, b) + atomic.StoreInt32(&c.writeDone, 1) + return n, nil +} + +func (c *testPacketConnOK) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { + if atomic.LoadInt32(&c.writeDone) == 0 { + return 0, 0, nil, nil + } + msg, err := icmp.ParseMessage(ipv4.ICMPTypeEcho.Protocol(), c.buf) + if err != nil { + return 0, 0, nil, err + } + msg.Type = ipv4.ICMPTypeEchoReply + buf, err := msg.Marshal(nil) + if err != nil { + return 0, 0, nil, err + } + time.Sleep(10 * time.Millisecond) + return copy(b, buf), 64, c.dst, nil +} + +func TestRunOK(t *testing.T) { + pinger := New("127.0.0.1") + pinger.Count = 1 + + err := pinger.Resolve() + AssertNoError(t, err) + + conn := new(testPacketConnOK) + + err = pinger.run(conn) + AssertTrue(t, err == nil) + + stats := pinger.Statistics() + AssertTrue(t, stats != nil) + if stats == nil { + t.FailNow() + } + AssertTrue(t, stats.PacketsSent == 1) + AssertTrue(t, stats.PacketsRecv == 1) + AssertTrue(t, stats.MinRtt >= 10*time.Millisecond) + AssertTrue(t, stats.MinRtt <= 12*time.Millisecond) +}