Don't resolve when creating Pinger (#65)

Move the DNS resolver out of the NewPinger() function in order to allow
adjusting of IPv4 vs IPv6 DNS resolution before running. This also
allows the user to verify resolution.
* Create new `New()` method that returns a bare default struct.
* Create new `Resolve()` method.
* Call `Resolve()` from `SetAddr()`.
* Call `Resolve()` automatically from `Run()`.
* Remove unecessary private `run()` method.

Update ping command for simplifed return values of `NewPinger()`.

Signed-off-by: Ben Kochie <superq@gmail.com>
This commit is contained in:
Ben Kochie 2020-09-14 07:41:27 +02:00 committed by GitHub
parent 805de73348
commit 8e89829cd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 104 additions and 75 deletions

131
ping.go
View File

@ -67,41 +67,37 @@ const (
)
var (
ipv4Proto = map[string]string{"ip": "ip4:icmp", "udp": "udp4"}
ipv6Proto = map[string]string{"ip": "ip6:ipv6-icmp", "udp": "udp6"}
ipv4Proto = map[string]string{"icmp": "ip4:icmp", "udp": "udp4"}
ipv6Proto = map[string]string{"icmp": "ip6:ipv6-icmp", "udp": "udp6"}
)
// NewPinger returns a new Pinger struct pointer
func NewPinger(addr string) (*Pinger, error) {
ipaddr, err := net.ResolveIPAddr("ip", addr)
if err != nil {
return nil, err
}
var ipv4 bool
if isIPv4(ipaddr.IP) {
ipv4 = true
} else if isIPv6(ipaddr.IP) {
ipv4 = false
}
// New returns a new Pinger struct pointer.
func New(addr string) *Pinger {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
return &Pinger{
ipaddr: ipaddr,
addr: addr,
Interval: time.Second,
Timeout: time.Second * 100000,
Count: -1,
id: r.Intn(math.MaxInt16),
network: "udp",
ipv4: ipv4,
Interval: time.Second,
Size: timeSliceLength,
Timeout: time.Second * 100000,
Tracker: r.Int63n(math.MaxInt64),
addr: addr,
done: make(chan bool),
}, nil
id: r.Intn(math.MaxInt16),
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
}
}
// Pinger represents ICMP packet sender/receiver
// NewPinger returns a new Pinger and resolves the address.
func NewPinger(addr string) (*Pinger, error) {
p := New(addr)
return p, p.Resolve()
}
// Pinger represents a packet sender/receiver.
type Pinger struct {
// Interval is the wait time between each packet send. Default is 1s.
Interval time.Duration
@ -152,7 +148,10 @@ type Pinger struct {
size int
id int
sequence int
network string
// network is one of "ip", "ip4", or "ip6".
network string
// protocol is "icmp" or "udp".
protocol string
}
type packet struct {
@ -219,16 +218,10 @@ type Statistics struct {
// SetIPAddr sets the ip address of the target host.
func (p *Pinger) SetIPAddr(ipaddr *net.IPAddr) {
var ipv4 bool
if isIPv4(ipaddr.IP) {
ipv4 = true
} else if isIPv6(ipaddr.IP) {
ipv4 = false
}
p.ipv4 = isIPv4(ipaddr.IP)
p.ipaddr = ipaddr
p.addr = ipaddr.String()
p.ipv4 = ipv4
}
// IPAddr returns the ip address of the target host.
@ -236,16 +229,30 @@ func (p *Pinger) IPAddr() *net.IPAddr {
return p.ipaddr
}
// SetAddr resolves and sets the ip address of the target host, addr can be a
// DNS name like "www.google.com" or IP like "127.0.0.1".
func (p *Pinger) SetAddr(addr string) error {
ipaddr, err := net.ResolveIPAddr("ip", addr)
// Resolve does the DNS lookup for the Pinger address and sets IP protocol.
func (p *Pinger) Resolve() error {
ipaddr, err := net.ResolveIPAddr(p.network, p.addr)
if err != nil {
return err
}
p.SetIPAddr(ipaddr)
p.ipv4 = isIPv4(ipaddr.IP)
p.ipaddr = ipaddr
return nil
}
// SetAddr resolves and sets the ip address of the target host, addr can be a
// DNS name like "www.google.com" or IP like "127.0.0.1".
func (p *Pinger) SetAddr(addr string) error {
oldAddr := p.addr
p.addr = addr
err := p.Resolve()
if err != nil {
p.addr = oldAddr
return err
}
return nil
}
@ -254,39 +261,57 @@ func (p *Pinger) Addr() string {
return p.addr
}
// SetNetwork allows configuration of DNS resolution.
// * "ip" will automatically select IPv4 or IPv6.
// * "ip4" will select IPv4.
// * "ip6" will select IPv6.
func (p *Pinger) SetNetwork(n string) {
switch n {
case "ip4":
p.network = "ip4"
case "ip6":
p.network = "ip6"
default:
p.network = "ip"
}
}
// SetPrivileged sets the type of ping pinger will send.
// false means pinger will send an "unprivileged" UDP ping.
// true means pinger will send a "privileged" raw ICMP ping.
// NOTE: setting to true requires that it be run with super-user privileges.
func (p *Pinger) SetPrivileged(privileged bool) {
if privileged {
p.network = "ip"
p.protocol = "icmp"
} else {
p.network = "udp"
p.protocol = "udp"
}
}
// Privileged returns whether pinger is running in privileged mode.
func (p *Pinger) Privileged() bool {
return p.network == "ip"
return p.protocol == "icmp"
}
// Run runs the pinger. This is a blocking function that will exit when it's
// done. If Count or Interval are not specified, it will run continuously until
// it is interrupted.
func (p *Pinger) Run() {
p.run()
}
func (p *Pinger) run() {
var conn *icmp.PacketConn
var err error
if p.ipaddr == nil {
err = p.Resolve()
}
if err != nil {
return
}
if p.ipv4 {
if conn = p.listen(ipv4Proto[p.network]); conn == nil {
if conn = p.listen(ipv4Proto[p.protocol]); conn == nil {
return
}
conn.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL, true)
} else {
if conn = p.listen(ipv6Proto[p.network]); conn == nil {
if conn = p.listen(ipv6Proto[p.protocol]); conn == nil {
return
}
conn.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit, true)
@ -300,7 +325,7 @@ func (p *Pinger) run() {
wg.Add(1)
go p.recvICMP(conn, recv, &wg)
err := p.sendICMP(conn)
err = p.sendICMP(conn)
if err != nil {
fmt.Println(err.Error())
}
@ -468,8 +493,8 @@ func (p *Pinger) processPacket(recv *packet) error {
switch pkt := m.Body.(type) {
case *icmp.Echo:
// If we are privileged, we can match icmp.ID
if p.network == "ip" {
// If we are priviledged, we can match icmp.ID
if p.protocol == "icmp" {
// Check if reply from same ID
if pkt.ID != p.id {
return nil
@ -514,7 +539,7 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error {
}
var dst net.Addr = p.ipaddr
if p.network == "udp" {
if p.protocol == "udp" {
dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone}
}
@ -578,10 +603,6 @@ func isIPv4(ip net.IP) bool {
return len(ip.To4()) == net.IPv4len
}
func isIPv6(ip net.IP) bool {
return len(ip) == net.IPv6len
}
func timeToBytes(t time.Time) []byte {
nsec := t.UnixNano()
b := make([]byte, 8)

View File

@ -89,7 +89,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) {
func TestProcessPacket_IDMismatch(t *testing.T) {
pinger := makeTestPinger()
pinger.network = "ip" // ID is only checked on "ip" network
pinger.protocol = "icmp" // ID is only checked on "icmp" protocol
shouldBe0 := 0
// this function should not be called because the tracker is mismatched
pinger.OnRecv = func(pkt *Packet) {
@ -226,7 +226,8 @@ func TestProcessPacket_PacketTooSmall(t *testing.T) {
}
func TestNewPingerValid(t *testing.T) {
p, err := NewPinger("www.google.com")
p := New("www.google.com")
err := p.Resolve()
AssertNoError(t, err)
AssertEqualStrings(t, "www.google.com", p.Addr())
// DNS names should resolve into IP addresses
@ -243,9 +244,10 @@ func TestNewPingerValid(t *testing.T) {
// Test setting to ipv6 address
err = p.SetAddr("ipv6.google.com")
AssertNoError(t, err)
AssertTrue(t, isIPv6(p.IPAddr().IP))
AssertFalse(t, isIPv4(p.IPAddr().IP))
p, err = NewPinger("localhost")
p = New("localhost")
err = p.Resolve()
AssertNoError(t, err)
AssertEqualStrings(t, "localhost", p.Addr())
// DNS names should resolve into IP addresses
@ -262,9 +264,10 @@ func TestNewPingerValid(t *testing.T) {
// Test setting to ipv6 address
err = p.SetAddr("ipv6.google.com")
AssertNoError(t, err)
AssertTrue(t, isIPv6(p.IPAddr().IP))
AssertFalse(t, isIPv4(p.IPAddr().IP))
p, err = NewPinger("127.0.0.1")
p = New("127.0.0.1")
err = p.Resolve()
AssertNoError(t, err)
AssertEqualStrings(t, "127.0.0.1", p.Addr())
AssertTrue(t, isIPv4(p.IPAddr().IP))
@ -279,14 +282,15 @@ func TestNewPingerValid(t *testing.T) {
// Test setting to ipv6 address
err = p.SetAddr("ipv6.google.com")
AssertNoError(t, err)
AssertTrue(t, isIPv6(p.IPAddr().IP))
AssertFalse(t, isIPv4(p.IPAddr().IP))
p, err = NewPinger("ipv6.google.com")
p = New("ipv6.google.com")
err = p.Resolve()
AssertNoError(t, err)
AssertEqualStrings(t, "ipv6.google.com", p.Addr())
// DNS names should resolve into IP addresses
AssertNotEqualStrings(t, "ipv6.google.com", p.IPAddr().String())
AssertTrue(t, isIPv6(p.IPAddr().IP))
AssertFalse(t, isIPv4(p.IPAddr().IP))
AssertFalse(t, p.Privileged())
// Test that SetPrivileged works
p.SetPrivileged(true)
@ -298,13 +302,14 @@ func TestNewPingerValid(t *testing.T) {
// Test setting to ipv6 address
err = p.SetAddr("ipv6.google.com")
AssertNoError(t, err)
AssertTrue(t, isIPv6(p.IPAddr().IP))
AssertFalse(t, isIPv4(p.IPAddr().IP))
// ipv6 localhost:
p, err = NewPinger("::1")
p = New("::1")
err = p.Resolve()
AssertNoError(t, err)
AssertEqualStrings(t, "::1", p.Addr())
AssertTrue(t, isIPv6(p.IPAddr().IP))
AssertFalse(t, isIPv4(p.IPAddr().IP))
AssertFalse(t, p.Privileged())
// Test that SetPrivileged works
p.SetPrivileged(true)
@ -316,7 +321,7 @@ func TestNewPingerValid(t *testing.T) {
// Test setting to ipv6 address
err = p.SetAddr("ipv6.google.com")
AssertNoError(t, err)
AssertTrue(t, isIPv6(p.IPAddr().IP))
AssertFalse(t, isIPv4(p.IPAddr().IP))
}
func TestNewPingerInvalid(t *testing.T) {
@ -343,7 +348,8 @@ func TestSetIPAddr(t *testing.T) {
}
// Create a localhost ipv4 pinger
p, err := NewPinger("localhost")
p := New("localhost")
err = p.Resolve()
AssertNoError(t, err)
AssertEqualStrings(t, "localhost", p.Addr())
@ -354,7 +360,8 @@ func TestSetIPAddr(t *testing.T) {
func TestStatisticsSunny(t *testing.T) {
// Create a localhost ipv4 pinger
p, err := NewPinger("localhost")
p := New("localhost")
err := p.Resolve()
AssertNoError(t, err)
AssertEqualStrings(t, "localhost", p.Addr())
@ -399,7 +406,8 @@ func TestStatisticsSunny(t *testing.T) {
func TestStatisticsLossy(t *testing.T) {
// Create a localhost ipv4 pinger
p, err := NewPinger("localhost")
p := New("localhost")
err := p.Resolve()
AssertNoError(t, err)
AssertEqualStrings(t, "localhost", p.Addr())
@ -444,11 +452,11 @@ func TestStatisticsLossy(t *testing.T) {
// Test helpers
func makeTestPinger() *Pinger {
pinger, _ := NewPinger("127.0.0.1")
pinger := New("127.0.0.1")
pinger.ipv4 = true
pinger.addr = "127.0.0.1"
pinger.network = "ip"
pinger.protocol = "icmp"
pinger.id = 123
pinger.Tracker = 456
pinger.Size = 0
@ -497,11 +505,11 @@ func AssertFalse(t *testing.T, b bool) {
}
func BenchmarkProcessPacket(b *testing.B) {
pinger, _ := NewPinger("127.0.0.1")
pinger := New("127.0.0.1")
pinger.ipv4 = true
pinger.addr = "127.0.0.1"
pinger.network = "ip4:icmp"
pinger.protocol = "ip4:icmp"
pinger.id = 123
pinger.Tracker = 456