diff --git a/cmd/ping/ping.go b/cmd/ping/ping.go index 0bed3fc..cecd331 100644 --- a/cmd/ping/ping.go +++ b/cmd/ping/ping.go @@ -13,7 +13,7 @@ import ( var usage = ` Usage: - ping [-c count] [-i interval] [-t timeout] [--privileged] host + ping [-c count] [-i interval] [-t timeout] [-I iface] [--privileged] host Examples: @@ -37,11 +37,12 @@ Examples: ` func main() { - timeout := flag.Duration("t", time.Second*100000, "") - interval := flag.Duration("i", time.Second, "") - count := flag.Int("c", -1, "") - size := flag.Int("s", 24, "") - ttl := flag.Int("l", 64, "TTL") + timeout := flag.Duration("t", time.Second*100000, "time to wait for response") + interval := flag.Duration("i", time.Second, "interval between sending each packet") + count := flag.Int("c", -1, "stop after replies") + size := flag.Int("s", 24, "number of data bytes to be sent") + ttl := flag.Int("l", 64, "define time to live") + iface := flag.String("I", "", "interface name") privileged := flag.Bool("privileged", false, "") flag.Usage = func() { fmt.Print(usage) @@ -90,6 +91,7 @@ func main() { pinger.Interval = *interval pinger.Timeout = *timeout pinger.TTL = *ttl + pinger.Iface = *iface pinger.SetPrivileged(*privileged) fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr()) diff --git a/packetconn.go b/packetconn.go index 38e17e3..0c9bfd9 100644 --- a/packetconn.go +++ b/packetconn.go @@ -1,8 +1,11 @@ package ping import ( + "errors" "net" + "reflect" "runtime" + "syscall" "time" "golang.org/x/net/icmp" @@ -18,6 +21,7 @@ type packetConn interface { SetReadDeadline(t time.Time) error WriteTo(b []byte, dst net.Addr) (int, error) SetTTL(ttl int) + BindToDevice(iface string) error } type icmpConn struct { @@ -37,6 +41,38 @@ func (c *icmpConn) SetReadDeadline(t time.Time) error { return c.c.SetReadDeadline(t) } +func getConnFD(conn *icmp.PacketConn) (fd int) { + var packetConn reflect.Value + + defer func() { + if r := recover(); r != nil { + fd = -1 + } + }() + + if conn.IPv4PacketConn() != nil { + packetConn = reflect.ValueOf(conn.IPv4PacketConn().PacketConn) + } else if conn.IPv6PacketConn() != nil { + packetConn = reflect.ValueOf(conn.IPv6PacketConn().PacketConn) + } else { + return -1 + } + + netFD := reflect.Indirect(reflect.Indirect(packetConn).FieldByName("fd")) + pollFD := netFD.FieldByName("pfd") + systemFD := pollFD.FieldByName("Sysfd") + return int(systemFD.Int()) +} + +func (c *icmpConn) BindToDevice(ifName string) error { + if runtime.GOOS == "linux" { + if fd := getConnFD(c.c); fd >= 0 { + return syscall.BindToDevice(fd, ifName) + } + } + return errors.New("bind to interface unsupported") // FIXME: or nil +} + func (c *icmpConn) WriteTo(b []byte, dst net.Addr) (int, error) { if c.c.IPv6PacketConn() != nil { if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil { diff --git a/ping.go b/ping.go index e1c09ee..8f2e30a 100644 --- a/ping.go +++ b/ping.go @@ -182,6 +182,9 @@ type Pinger struct { // Source is the source IP address Source string + // Iface used to send/recv ICMP messages + Iface string + // Channel and mutex used to communicate when the Pinger should stop between goroutines. done chan interface{} lock sync.Mutex @@ -418,6 +421,12 @@ func (p *Pinger) Run() error { } defer conn.Close() + if p.Iface != "" { + if err = conn.BindToDevice(p.Iface); err != nil { + return err + } + } + conn.SetTTL(p.TTL) return p.run(conn) } diff --git a/ping_test.go b/ping_test.go index b8755e7..6ae4f91 100644 --- a/ping_test.go +++ b/ping_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "runtime" "runtime/debug" "sync/atomic" "testing" @@ -385,6 +386,27 @@ func TestSetIPAddr(t *testing.T) { AssertEqualStrings(t, googleaddr.String(), p.Addr()) } +func TestBindToDevice(t *testing.T) { + // Create a localhost ipv4 pinger + pinger := New("127.0.0.1") + pinger.ipv4 = true + pinger.Count = 1 + + // Set loopback interface: "lo" + pinger.Iface = "lo" + err := pinger.Run() + if runtime.GOOS == "linux" { + AssertNoError(t, err) + } else { + AssertError(t, err, "other platforms unsupported this feature") + } + + // Set fake interface: "L()0pB@cK" + pinger.Iface = "L()0pB@cK" + err = pinger.Run() + AssertError(t, err, "device not found") +} + func TestEmptyIPAddr(t *testing.T) { _, err := NewPinger("") AssertError(t, err, "empty pinger did not return an error") @@ -642,6 +664,7 @@ func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTyp func (c testPacketConn) SetFlagTTL() error { return nil } func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil } func (c testPacketConn) SetTTL(t int) {} +func (c testPacketConn) BindToDevice(iface string) error { return nil } func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { return 0, 0, nil, nil diff --git a/utils_linux.go b/utils_linux.go index ba785d2..0ba7677 100644 --- a/utils_linux.go +++ b/utils_linux.go @@ -1,7 +1,15 @@ -// +build linux +//go:build linux package ping +import ( + "errors" + "reflect" + "syscall" + + "golang.org/x/net/icmp" +) + // Returns the length of an ICMP message. func (p *Pinger) getMessageLength() int { return p.Size + 8 @@ -17,3 +25,33 @@ func (p *Pinger) matchID(ID int) bool { } return true } + +func getConnFD(conn *icmp.PacketConn) (fd int) { + var packetConn reflect.Value + + defer func() { + if r := recover(); r != nil { + fd = -1 + } + }() + + if conn.IPv4PacketConn() != nil { + packetConn = reflect.ValueOf(conn.IPv4PacketConn().PacketConn) + } else if conn.IPv6PacketConn() != nil { + packetConn = reflect.ValueOf(conn.IPv6PacketConn().PacketConn) + } else { + return -1 + } + + netFD := reflect.Indirect(reflect.Indirect(packetConn).FieldByName("fd")) + pollFD := netFD.FieldByName("pfd") + systemFD := pollFD.FieldByName("Sysfd") + return int(systemFD.Int()) +} + +func (c *icmpConn) BindToDevice(ifName string) error { + if fd := getConnFD(c.c); fd >= 0 { + return syscall.BindToDevice(fd, ifName) + } + return errors.New("bind to interface unsupported") +} diff --git a/utils_other.go b/utils_other.go index 6ccbe78..1bdbdb0 100644 --- a/utils_other.go +++ b/utils_other.go @@ -1,4 +1,4 @@ -// +build !linux,!windows +//go:build !linux && !windows package ping @@ -14,3 +14,7 @@ func (p *Pinger) matchID(ID int) bool { } return true } + +func (c *icmpConn) BindToDevice(ifName string) error { + return errors.New("bind to interface unsupported") +} diff --git a/utils_windows.go b/utils_windows.go index ba642bc..a00ffb7 100644 --- a/utils_windows.go +++ b/utils_windows.go @@ -1,4 +1,4 @@ -// +build windows +//go:build windows package ping @@ -22,3 +22,7 @@ func (p *Pinger) matchID(ID int) bool { } return true } + +func (c *icmpConn) BindToDevice(ifName string) error { + return errors.New("bind to interface unsupported") +}