diff --git a/README.md b/README.md index 34662af..acc7c9c 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,9 @@ This library also supports setting the `SO_MARK` socket option which is equivale flag in standard ping binaries on linux. Setting this option requires the `CAP_NET_ADMIN` capability (via `setcap` or elevated privileges). You can set a mark (ex: 100) with `pinger.SetMark(100)` in your code. +Setting the "Don't Fragment" bit is supported under Linux which is equivalent to `ping -Mdo`. +You can enable this with `pinger.SetDoNotFragment(true)`. + ### Windows You must use `pinger.SetPrivileged(true)`, otherwise you will receive diff --git a/packetconn.go b/packetconn.go index 993015b..f171b3f 100644 --- a/packetconn.go +++ b/packetconn.go @@ -19,6 +19,7 @@ type packetConn interface { WriteTo(b []byte, dst net.Addr) (int, error) SetTTL(ttl int) SetMark(m uint) error + SetDoNotFragment() error SetTOS(tos int) } diff --git a/ping.go b/ping.go index ae326b9..ff08121 100644 --- a/ping.go +++ b/ping.go @@ -85,6 +85,7 @@ var ( ipv6Proto = map[string]string{"icmp": "ip6:ipv6-icmp", "udp": "udp6"} ErrMarkNotSupported = errors.New("setting SO_MARK socket option is not supported on this platform") + ErrDFNotSupported = errors.New("setting do-not-fragment bit is not supported on this platform") ) // New returns a new Pinger struct pointer. @@ -196,6 +197,9 @@ type Pinger struct { // mark is a SO_MARK (fwmark) set on outgoing icmp packets mark uint + // df when true sets the do-not-fragment bit in the outer IP or IPv6 header + df bool + // trackerUUIDs is the list of UUIDs being used for sending packets. trackerUUIDs []uuid.UUID @@ -420,6 +424,11 @@ func (p *Pinger) Mark() uint { return p.mark } +// SetDoNotFragment sets the do-not-fragment bit in the outer IP header to the desired value. +func (p *Pinger) SetDoNotFragment(df bool) { + p.df = df +} + // Run runs the pinger. This is a blocking function that will exit when it's // done. If Count or Interval are not specified, it will run continuously until // it is interrupted. @@ -453,6 +462,12 @@ func (p *Pinger) RunWithContext(ctx context.Context) error { } } + if p.df { + if err := conn.SetDoNotFragment(); err != nil { + return fmt.Errorf("error setting do-not-fragment: %v", err) + } + } + conn.SetTTL(p.TTL) conn.SetTOS(p.TOS) return p.run(ctx, conn) diff --git a/ping_test.go b/ping_test.go index 32a9d40..3e96611 100644 --- a/ping_test.go +++ b/ping_test.go @@ -660,6 +660,7 @@ func (c testPacketConn) SetFlagTTL() error { return nil } func (c testPacketConn) SetReadDeadline(t time.Time) error { return nil } func (c testPacketConn) SetTTL(t int) {} func (c testPacketConn) SetMark(m uint) error { return nil } +func (c testPacketConn) SetDoNotFragment() error { return nil } func (c testPacketConn) SetTOS(t int) {} func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) { diff --git a/utils_linux.go b/utils_linux.go index 2461eb9..3cb953b 100644 --- a/utils_linux.go +++ b/utils_linux.go @@ -65,6 +65,42 @@ func (c *icmpV6Conn) SetMark(mark uint) error { ) } +// SetDoNotFragment sets the do-not-fragment bit in the IP header of outgoing ICMP packets. +func (c *icmpConn) SetDoNotFragment() error { + fd, err := getFD(c.c) + if err != nil { + return err + } + return os.NewSyscallError( + "setsockopt", + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_DO), + ) +} + +// SetDoNotFragment sets the do-not-fragment bit in the IP header of outgoing ICMP packets. +func (c *icmpv4Conn) SetDoNotFragment() error { + fd, err := getFD(c.icmpConn.c) + if err != nil { + return err + } + return os.NewSyscallError( + "setsockopt", + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_MTU_DISCOVER, syscall.IP_PMTUDISC_DO), + ) +} + +// SetDoNotFragment sets the do-not-fragment bit in the IPv6 header of outgoing ICMPv6 packets. +func (c *icmpV6Conn) SetDoNotFragment() error { + fd, err := getFD(c.icmpConn.c) + if err != nil { + return err + } + return os.NewSyscallError( + "setsockopt", + syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_MTU_DISCOVER, syscall.IP_PMTUDISC_DO), + ) +} + // getFD gets the system file descriptor for an icmp.PacketConn func getFD(c *icmp.PacketConn) (uintptr, error) { v := reflect.ValueOf(c).Elem().FieldByName("c").Elem() diff --git a/utils_other.go b/utils_other.go index 80a0a90..90033d5 100644 --- a/utils_other.go +++ b/utils_other.go @@ -30,3 +30,18 @@ func (c *icmpv4Conn) SetMark(mark uint) error { func (c *icmpV6Conn) SetMark(mark uint) error { return ErrMarkNotSupported } + +// SetDoNotFragment sets the do-not-fragment bit in the IP header of outgoing ICMP packets. +func (c *icmpConn) SetDoNotFragment() error { + return ErrDFNotSupported +} + +// SetDoNotFragment sets the do-not-fragment bit in the IP header of outgoing ICMP packets. +func (c *icmpv4Conn) SetDoNotFragment() error { + return ErrDFNotSupported +} + +// SetDoNotFragment sets the do-not-fragment bit in the IPv6 header of outgoing ICMPv6 packets. +func (c *icmpV6Conn) SetDoNotFragment() error { + return ErrDFNotSupported +} diff --git a/utils_windows.go b/utils_windows.go index 5a5ba34..778fffa 100644 --- a/utils_windows.go +++ b/utils_windows.go @@ -41,3 +41,18 @@ func (c *icmpv4Conn) SetMark(mark uint) error { func (c *icmpV6Conn) SetMark(mark uint) error { return ErrMarkNotSupported } + +// SetDoNotFragment sets the do-not-fragment bit in the IP header of outgoing ICMP packets. +func (c *icmpConn) SetDoNotFragment() error { + return ErrDFNotSupported +} + +// SetDoNotFragment sets the do-not-fragment bit in the IP header of outgoing ICMP packets. +func (c *icmpv4Conn) SetDoNotFragment() error { + return ErrDFNotSupported +} + +// SetDoNotFragment sets the do-not-fragment bit in the IPv6 header of outgoing ICMPv6 packets. +func (c *icmpV6Conn) SetDoNotFragment() error { + return ErrDFNotSupported +}