diff --git a/ping.go b/ping.go index c48730b..c410894 100644 --- a/ping.go +++ b/ping.go @@ -67,41 +67,37 @@ const ( ) var ( - ipv4Proto = map[string]string{"ip": "ip4:icmp", "udp": "udp4"} - ipv6Proto = map[string]string{"ip": "ip6:ipv6-icmp", "udp": "udp6"} + ipv4Proto = map[string]string{"icmp": "ip4:icmp", "udp": "udp4"} + ipv6Proto = map[string]string{"icmp": "ip6:ipv6-icmp", "udp": "udp6"} ) -// NewPinger returns a new Pinger struct pointer -func NewPinger(addr string) (*Pinger, error) { - ipaddr, err := net.ResolveIPAddr("ip", addr) - if err != nil { - return nil, err - } - - var ipv4 bool - if isIPv4(ipaddr.IP) { - ipv4 = true - } else if isIPv6(ipaddr.IP) { - ipv4 = false - } - +// New returns a new Pinger struct pointer. +func New(addr string) *Pinger { r := rand.New(rand.NewSource(time.Now().UnixNano())) return &Pinger{ - ipaddr: ipaddr, - addr: addr, - Interval: time.Second, - Timeout: time.Second * 100000, Count: -1, - id: r.Intn(math.MaxInt16), - network: "udp", - ipv4: ipv4, + Interval: time.Second, Size: timeSliceLength, + Timeout: time.Second * 100000, Tracker: r.Int63n(math.MaxInt64), + + addr: addr, done: make(chan bool), - }, nil + id: r.Intn(math.MaxInt16), + ipaddr: nil, + ipv4: false, + network: "ip", + protocol: "udp", + } } -// Pinger represents ICMP packet sender/receiver +// NewPinger returns a new Pinger and resolves the address. +func NewPinger(addr string) (*Pinger, error) { + p := New(addr) + return p, p.Resolve() +} + +// Pinger represents a packet sender/receiver. type Pinger struct { // Interval is the wait time between each packet send. Default is 1s. Interval time.Duration @@ -152,7 +148,10 @@ type Pinger struct { size int id int sequence int - network string + // network is one of "ip", "ip4", or "ip6". + network string + // protocol is "icmp" or "udp". + protocol string } type packet struct { @@ -219,16 +218,10 @@ type Statistics struct { // SetIPAddr sets the ip address of the target host. func (p *Pinger) SetIPAddr(ipaddr *net.IPAddr) { - var ipv4 bool - if isIPv4(ipaddr.IP) { - ipv4 = true - } else if isIPv6(ipaddr.IP) { - ipv4 = false - } + p.ipv4 = isIPv4(ipaddr.IP) p.ipaddr = ipaddr p.addr = ipaddr.String() - p.ipv4 = ipv4 } // IPAddr returns the ip address of the target host. @@ -236,16 +229,30 @@ func (p *Pinger) IPAddr() *net.IPAddr { return p.ipaddr } -// SetAddr resolves and sets the ip address of the target host, addr can be a -// DNS name like "www.google.com" or IP like "127.0.0.1". -func (p *Pinger) SetAddr(addr string) error { - ipaddr, err := net.ResolveIPAddr("ip", addr) +// Resolve does the DNS lookup for the Pinger address and sets IP protocol. +func (p *Pinger) Resolve() error { + ipaddr, err := net.ResolveIPAddr(p.network, p.addr) if err != nil { return err } - p.SetIPAddr(ipaddr) + p.ipv4 = isIPv4(ipaddr.IP) + + p.ipaddr = ipaddr + + return nil +} + +// SetAddr resolves and sets the ip address of the target host, addr can be a +// DNS name like "www.google.com" or IP like "127.0.0.1". +func (p *Pinger) SetAddr(addr string) error { + oldAddr := p.addr p.addr = addr + err := p.Resolve() + if err != nil { + p.addr = oldAddr + return err + } return nil } @@ -254,39 +261,57 @@ func (p *Pinger) Addr() string { return p.addr } +// SetNetwork allows configuration of DNS resolution. +// * "ip" will automatically select IPv4 or IPv6. +// * "ip4" will select IPv4. +// * "ip6" will select IPv6. +func (p *Pinger) SetNetwork(n string) { + switch n { + case "ip4": + p.network = "ip4" + case "ip6": + p.network = "ip6" + default: + p.network = "ip" + } +} + // SetPrivileged sets the type of ping pinger will send. // false means pinger will send an "unprivileged" UDP ping. // true means pinger will send a "privileged" raw ICMP ping. // NOTE: setting to true requires that it be run with super-user privileges. func (p *Pinger) SetPrivileged(privileged bool) { if privileged { - p.network = "ip" + p.protocol = "icmp" } else { - p.network = "udp" + p.protocol = "udp" } } // Privileged returns whether pinger is running in privileged mode. func (p *Pinger) Privileged() bool { - return p.network == "ip" + return p.protocol == "icmp" } // Run runs the pinger. This is a blocking function that will exit when it's // done. If Count or Interval are not specified, it will run continuously until // it is interrupted. func (p *Pinger) Run() { - p.run() -} - -func (p *Pinger) run() { var conn *icmp.PacketConn + var err error + if p.ipaddr == nil { + err = p.Resolve() + } + if err != nil { + return + } if p.ipv4 { - if conn = p.listen(ipv4Proto[p.network]); conn == nil { + if conn = p.listen(ipv4Proto[p.protocol]); conn == nil { return } conn.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL, true) } else { - if conn = p.listen(ipv6Proto[p.network]); conn == nil { + if conn = p.listen(ipv6Proto[p.protocol]); conn == nil { return } conn.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit, true) @@ -300,7 +325,7 @@ func (p *Pinger) run() { wg.Add(1) go p.recvICMP(conn, recv, &wg) - err := p.sendICMP(conn) + err = p.sendICMP(conn) if err != nil { fmt.Println(err.Error()) } @@ -468,8 +493,8 @@ func (p *Pinger) processPacket(recv *packet) error { switch pkt := m.Body.(type) { case *icmp.Echo: - // If we are privileged, we can match icmp.ID - if p.network == "ip" { + // If we are priviledged, we can match icmp.ID + if p.protocol == "icmp" { // Check if reply from same ID if pkt.ID != p.id { return nil @@ -514,7 +539,7 @@ func (p *Pinger) sendICMP(conn *icmp.PacketConn) error { } var dst net.Addr = p.ipaddr - if p.network == "udp" { + if p.protocol == "udp" { dst = &net.UDPAddr{IP: p.ipaddr.IP, Zone: p.ipaddr.Zone} } @@ -578,10 +603,6 @@ func isIPv4(ip net.IP) bool { return len(ip.To4()) == net.IPv4len } -func isIPv6(ip net.IP) bool { - return len(ip) == net.IPv6len -} - func timeToBytes(t time.Time) []byte { nsec := t.UnixNano() b := make([]byte, 8) diff --git a/ping_test.go b/ping_test.go index a337450..18973b4 100644 --- a/ping_test.go +++ b/ping_test.go @@ -89,7 +89,7 @@ func TestProcessPacket_IgnoreNonEchoReplies(t *testing.T) { func TestProcessPacket_IDMismatch(t *testing.T) { pinger := makeTestPinger() - pinger.network = "ip" // ID is only checked on "ip" network + pinger.protocol = "icmp" // ID is only checked on "icmp" protocol shouldBe0 := 0 // this function should not be called because the tracker is mismatched pinger.OnRecv = func(pkt *Packet) { @@ -226,7 +226,8 @@ func TestProcessPacket_PacketTooSmall(t *testing.T) { } func TestNewPingerValid(t *testing.T) { - p, err := NewPinger("www.google.com") + p := New("www.google.com") + err := p.Resolve() AssertNoError(t, err) AssertEqualStrings(t, "www.google.com", p.Addr()) // DNS names should resolve into IP addresses @@ -243,9 +244,10 @@ func TestNewPingerValid(t *testing.T) { // Test setting to ipv6 address err = p.SetAddr("ipv6.google.com") AssertNoError(t, err) - AssertTrue(t, isIPv6(p.IPAddr().IP)) + AssertFalse(t, isIPv4(p.IPAddr().IP)) - p, err = NewPinger("localhost") + p = New("localhost") + err = p.Resolve() AssertNoError(t, err) AssertEqualStrings(t, "localhost", p.Addr()) // DNS names should resolve into IP addresses @@ -262,9 +264,10 @@ func TestNewPingerValid(t *testing.T) { // Test setting to ipv6 address err = p.SetAddr("ipv6.google.com") AssertNoError(t, err) - AssertTrue(t, isIPv6(p.IPAddr().IP)) + AssertFalse(t, isIPv4(p.IPAddr().IP)) - p, err = NewPinger("127.0.0.1") + p = New("127.0.0.1") + err = p.Resolve() AssertNoError(t, err) AssertEqualStrings(t, "127.0.0.1", p.Addr()) AssertTrue(t, isIPv4(p.IPAddr().IP)) @@ -279,14 +282,15 @@ func TestNewPingerValid(t *testing.T) { // Test setting to ipv6 address err = p.SetAddr("ipv6.google.com") AssertNoError(t, err) - AssertTrue(t, isIPv6(p.IPAddr().IP)) + AssertFalse(t, isIPv4(p.IPAddr().IP)) - p, err = NewPinger("ipv6.google.com") + p = New("ipv6.google.com") + err = p.Resolve() AssertNoError(t, err) AssertEqualStrings(t, "ipv6.google.com", p.Addr()) // DNS names should resolve into IP addresses AssertNotEqualStrings(t, "ipv6.google.com", p.IPAddr().String()) - AssertTrue(t, isIPv6(p.IPAddr().IP)) + AssertFalse(t, isIPv4(p.IPAddr().IP)) AssertFalse(t, p.Privileged()) // Test that SetPrivileged works p.SetPrivileged(true) @@ -298,13 +302,14 @@ func TestNewPingerValid(t *testing.T) { // Test setting to ipv6 address err = p.SetAddr("ipv6.google.com") AssertNoError(t, err) - AssertTrue(t, isIPv6(p.IPAddr().IP)) + AssertFalse(t, isIPv4(p.IPAddr().IP)) // ipv6 localhost: - p, err = NewPinger("::1") + p = New("::1") + err = p.Resolve() AssertNoError(t, err) AssertEqualStrings(t, "::1", p.Addr()) - AssertTrue(t, isIPv6(p.IPAddr().IP)) + AssertFalse(t, isIPv4(p.IPAddr().IP)) AssertFalse(t, p.Privileged()) // Test that SetPrivileged works p.SetPrivileged(true) @@ -316,7 +321,7 @@ func TestNewPingerValid(t *testing.T) { // Test setting to ipv6 address err = p.SetAddr("ipv6.google.com") AssertNoError(t, err) - AssertTrue(t, isIPv6(p.IPAddr().IP)) + AssertFalse(t, isIPv4(p.IPAddr().IP)) } func TestNewPingerInvalid(t *testing.T) { @@ -343,7 +348,8 @@ func TestSetIPAddr(t *testing.T) { } // Create a localhost ipv4 pinger - p, err := NewPinger("localhost") + p := New("localhost") + err = p.Resolve() AssertNoError(t, err) AssertEqualStrings(t, "localhost", p.Addr()) @@ -354,7 +360,8 @@ func TestSetIPAddr(t *testing.T) { func TestStatisticsSunny(t *testing.T) { // Create a localhost ipv4 pinger - p, err := NewPinger("localhost") + p := New("localhost") + err := p.Resolve() AssertNoError(t, err) AssertEqualStrings(t, "localhost", p.Addr()) @@ -399,7 +406,8 @@ func TestStatisticsSunny(t *testing.T) { func TestStatisticsLossy(t *testing.T) { // Create a localhost ipv4 pinger - p, err := NewPinger("localhost") + p := New("localhost") + err := p.Resolve() AssertNoError(t, err) AssertEqualStrings(t, "localhost", p.Addr()) @@ -444,11 +452,11 @@ func TestStatisticsLossy(t *testing.T) { // Test helpers func makeTestPinger() *Pinger { - pinger, _ := NewPinger("127.0.0.1") + pinger := New("127.0.0.1") pinger.ipv4 = true pinger.addr = "127.0.0.1" - pinger.network = "ip" + pinger.protocol = "icmp" pinger.id = 123 pinger.Tracker = 456 pinger.Size = 0 @@ -497,11 +505,11 @@ func AssertFalse(t *testing.T, b bool) { } func BenchmarkProcessPacket(b *testing.B) { - pinger, _ := NewPinger("127.0.0.1") + pinger := New("127.0.0.1") pinger.ipv4 = true pinger.addr = "127.0.0.1" - pinger.network = "ip4:icmp" + pinger.protocol = "ip4:icmp" pinger.id = 123 pinger.Tracker = 456