Add ability to set outgouing interface (only Linux support)

This commit is contained in:
TimofeyAf 2022-02-22 23:02:39 +03:00
parent 779d1e9195
commit 3ef8b51873
7 changed files with 125 additions and 9 deletions

View File

@ -13,7 +13,7 @@ import (
var usage = ` var usage = `
Usage: Usage:
ping [-c count] [-i interval] [-t timeout] [--privileged] host ping [-c count] [-i interval] [-t timeout] [-I iface] [--privileged] host
Examples: Examples:
@ -37,11 +37,12 @@ Examples:
` `
func main() { func main() {
timeout := flag.Duration("t", time.Second*100000, "") timeout := flag.Duration("t", time.Second*100000, "time to wait for response")
interval := flag.Duration("i", time.Second, "") interval := flag.Duration("i", time.Second, "interval between sending each packet")
count := flag.Int("c", -1, "") count := flag.Int("c", -1, "stop after replies")
size := flag.Int("s", 24, "") size := flag.Int("s", 24, "number of data bytes to be sent")
ttl := flag.Int("l", 64, "TTL") ttl := flag.Int("l", 64, "define time to live")
iface := flag.String("I", "", "interface name")
privileged := flag.Bool("privileged", false, "") privileged := flag.Bool("privileged", false, "")
flag.Usage = func() { flag.Usage = func() {
fmt.Print(usage) fmt.Print(usage)
@ -90,6 +91,7 @@ func main() {
pinger.Interval = *interval pinger.Interval = *interval
pinger.Timeout = *timeout pinger.Timeout = *timeout
pinger.TTL = *ttl pinger.TTL = *ttl
pinger.Iface = *iface
pinger.SetPrivileged(*privileged) pinger.SetPrivileged(*privileged)
fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr()) fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr())

View File

@ -1,8 +1,11 @@
package ping package ping
import ( import (
"errors"
"net" "net"
"reflect"
"runtime" "runtime"
"syscall"
"time" "time"
"golang.org/x/net/icmp" "golang.org/x/net/icmp"
@ -18,6 +21,7 @@ type packetConn interface {
SetReadDeadline(t time.Time) error SetReadDeadline(t time.Time) error
WriteTo(b []byte, dst net.Addr) (int, error) WriteTo(b []byte, dst net.Addr) (int, error)
SetTTL(ttl int) SetTTL(ttl int)
BindToDevice(iface string) error
} }
type icmpConn struct { type icmpConn struct {
@ -37,6 +41,38 @@ func (c *icmpConn) SetReadDeadline(t time.Time) error {
return c.c.SetReadDeadline(t) return c.c.SetReadDeadline(t)
} }
func getConnFD(conn *icmp.PacketConn) (fd int) {
var packetConn reflect.Value
defer func() {
if r := recover(); r != nil {
fd = -1
}
}()
if conn.IPv4PacketConn() != nil {
packetConn = reflect.ValueOf(conn.IPv4PacketConn().PacketConn)
} else if conn.IPv6PacketConn() != nil {
packetConn = reflect.ValueOf(conn.IPv6PacketConn().PacketConn)
} else {
return -1
}
netFD := reflect.Indirect(reflect.Indirect(packetConn).FieldByName("fd"))
pollFD := netFD.FieldByName("pfd")
systemFD := pollFD.FieldByName("Sysfd")
return int(systemFD.Int())
}
func (c *icmpConn) BindToDevice(ifName string) error {
if runtime.GOOS == "linux" {
if fd := getConnFD(c.c); fd >= 0 {
return syscall.BindToDevice(fd, ifName)
}
}
return errors.New("bind to interface unsupported") // FIXME: or nil
}
func (c *icmpConn) WriteTo(b []byte, dst net.Addr) (int, error) { func (c *icmpConn) WriteTo(b []byte, dst net.Addr) (int, error) {
if c.c.IPv6PacketConn() != nil { if c.c.IPv6PacketConn() != nil {
if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil { if err := c.c.IPv6PacketConn().SetHopLimit(c.ttl); err != nil {

View File

@ -182,6 +182,9 @@ type Pinger struct {
// Source is the source IP address // Source is the source IP address
Source string Source string
// Iface used to send/recv ICMP messages
Iface string
// Channel and mutex used to communicate when the Pinger should stop between goroutines. // Channel and mutex used to communicate when the Pinger should stop between goroutines.
done chan interface{} done chan interface{}
lock sync.Mutex lock sync.Mutex
@ -418,6 +421,12 @@ func (p *Pinger) Run() error {
} }
defer conn.Close() defer conn.Close()
if p.Iface != "" {
if err = conn.BindToDevice(p.Iface); err != nil {
return err
}
}
conn.SetTTL(p.TTL) conn.SetTTL(p.TTL)
return p.run(conn) return p.run(conn)
} }

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"runtime"
"runtime/debug" "runtime/debug"
"sync/atomic" "sync/atomic"
"testing" "testing"
@ -385,6 +386,27 @@ func TestSetIPAddr(t *testing.T) {
AssertEqualStrings(t, googleaddr.String(), p.Addr()) AssertEqualStrings(t, googleaddr.String(), p.Addr())
} }
func TestBindToDevice(t *testing.T) {
// Create a localhost ipv4 pinger
pinger := New("127.0.0.1")
pinger.ipv4 = true
pinger.Count = 1
// Set loopback interface: "lo"
pinger.Iface = "lo"
err := pinger.Run()
if runtime.GOOS == "linux" {
AssertNoError(t, err)
} else {
AssertError(t, err, "other platforms unsupported this feature")
}
// Set fake interface: "L()0pB@cK"
pinger.Iface = "L()0pB@cK"
err = pinger.Run()
AssertError(t, err, "device not found")
}
func TestEmptyIPAddr(t *testing.T) { func TestEmptyIPAddr(t *testing.T) {
_, err := NewPinger("") _, err := NewPinger("")
AssertError(t, err, "empty pinger did not return an error") AssertError(t, err, "empty pinger did not return an error")
@ -642,6 +664,7 @@ func (c testPacketConn) ICMPRequestType() icmp.Type { return ipv4.ICMPTyp
func (c testPacketConn) SetFlagTTL() error { return nil } func (c testPacketConn) SetFlagTTL() error { return nil }
func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil } func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil }
func (c testPacketConn) SetTTL(t int) {} func (c testPacketConn) SetTTL(t int) {}
func (c testPacketConn) BindToDevice(iface string) error { return nil }
func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, nil, nil return 0, 0, nil, nil

View File

@ -1,7 +1,15 @@
// +build linux //go:build linux
package ping package ping
import (
"errors"
"reflect"
"syscall"
"golang.org/x/net/icmp"
)
// Returns the length of an ICMP message. // Returns the length of an ICMP message.
func (p *Pinger) getMessageLength() int { func (p *Pinger) getMessageLength() int {
return p.Size + 8 return p.Size + 8
@ -17,3 +25,33 @@ func (p *Pinger) matchID(ID int) bool {
} }
return true return true
} }
func getConnFD(conn *icmp.PacketConn) (fd int) {
var packetConn reflect.Value
defer func() {
if r := recover(); r != nil {
fd = -1
}
}()
if conn.IPv4PacketConn() != nil {
packetConn = reflect.ValueOf(conn.IPv4PacketConn().PacketConn)
} else if conn.IPv6PacketConn() != nil {
packetConn = reflect.ValueOf(conn.IPv6PacketConn().PacketConn)
} else {
return -1
}
netFD := reflect.Indirect(reflect.Indirect(packetConn).FieldByName("fd"))
pollFD := netFD.FieldByName("pfd")
systemFD := pollFD.FieldByName("Sysfd")
return int(systemFD.Int())
}
func (c *icmpConn) BindToDevice(ifName string) error {
if fd := getConnFD(c.c); fd >= 0 {
return syscall.BindToDevice(fd, ifName)
}
return errors.New("bind to interface unsupported")
}

View File

@ -1,4 +1,4 @@
// +build !linux,!windows //go:build !linux && !windows
package ping package ping
@ -14,3 +14,7 @@ func (p *Pinger) matchID(ID int) bool {
} }
return true return true
} }
func (c *icmpConn) BindToDevice(ifName string) error {
return errors.New("bind to interface unsupported")
}

View File

@ -1,4 +1,4 @@
// +build windows //go:build windows
package ping package ping
@ -22,3 +22,7 @@ func (p *Pinger) matchID(ID int) bool {
} }
return true return true
} }
func (c *icmpConn) BindToDevice(ifName string) error {
return errors.New("bind to interface unsupported")
}