mirror of
https://github.com/go-ping/ping.git
synced 2025-08-01 21:49:06 +00:00
Create a Run API with a context object
Signed-off-by: TheRushingWookie <3181551+TheRushingWookie@users.noreply.github.com>
This commit is contained in:
parent
a4fbe74944
commit
5e08633e1b
23
ping.go
23
ping.go
@ -53,6 +53,7 @@ package ping
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@ -420,6 +421,13 @@ func (p *Pinger) Mark() uint {
|
||||
// done. If Count or Interval are not specified, it will run continuously until
|
||||
// it is interrupted.
|
||||
func (p *Pinger) Run() error {
|
||||
return p.RunWithContext(context.Background())
|
||||
}
|
||||
|
||||
// RunWithContext runs the pinger with a context. This is a blocking function that will exit when it's
|
||||
// done or if the context is canceled. If Count or Interval are not specified, it will run continuously until
|
||||
// it is interrupted.
|
||||
func (p *Pinger) RunWithContext(ctx context.Context) error {
|
||||
var conn packetConn
|
||||
var err error
|
||||
if p.Size < timeSliceLength+trackerLength {
|
||||
@ -444,10 +452,10 @@ func (p *Pinger) Run() error {
|
||||
|
||||
conn.SetTTL(p.TTL)
|
||||
conn.SetTOS(p.TOS)
|
||||
return p.run(conn)
|
||||
return p.run(ctx, conn)
|
||||
}
|
||||
|
||||
func (p *Pinger) run(conn packetConn) error {
|
||||
func (p *Pinger) run(ctx context.Context, conn packetConn) error {
|
||||
if err := conn.SetFlagTTL(); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -460,7 +468,16 @@ func (p *Pinger) run(conn packetConn) error {
|
||||
handler()
|
||||
}
|
||||
|
||||
var g errgroup.Group
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
p.Stop()
|
||||
case <-p.done:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
defer p.Stop()
|
||||
|
53
ping_test.go
53
ping_test.go
@ -2,6 +2,7 @@ package ping
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
@ -686,7 +687,7 @@ func TestRunBadWrite(t *testing.T) {
|
||||
|
||||
var conn testPacketConnBadWrite
|
||||
|
||||
err = pinger.run(conn)
|
||||
err = pinger.run(context.Background(), conn)
|
||||
AssertTrue(t, err != nil)
|
||||
|
||||
stats := pinger.Statistics()
|
||||
@ -715,7 +716,7 @@ func TestRunBadRead(t *testing.T) {
|
||||
|
||||
var conn testPacketConnBadRead
|
||||
|
||||
err = pinger.run(conn)
|
||||
err = pinger.run(context.Background(), conn)
|
||||
AssertTrue(t, err != nil)
|
||||
|
||||
stats := pinger.Statistics()
|
||||
@ -773,7 +774,7 @@ func TestRunOK(t *testing.T) {
|
||||
|
||||
conn := new(testPacketConnOK)
|
||||
|
||||
err = pinger.run(conn)
|
||||
err = pinger.run(context.Background(), conn)
|
||||
AssertTrue(t, err == nil)
|
||||
|
||||
stats := pinger.Statistics()
|
||||
@ -786,3 +787,49 @@ func TestRunOK(t *testing.T) {
|
||||
AssertTrue(t, stats.MinRtt >= 10*time.Millisecond)
|
||||
AssertTrue(t, stats.MinRtt <= 12*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestRunWithTimeoutContext(t *testing.T) {
|
||||
pinger := New("127.0.0.1")
|
||||
|
||||
err := pinger.Resolve()
|
||||
AssertNoError(t, err)
|
||||
|
||||
conn := new(testPacketConnOK)
|
||||
|
||||
start := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
err = pinger.run(ctx, conn)
|
||||
AssertTrue(t, err == nil)
|
||||
elapsedTime := time.Since(start)
|
||||
AssertTrue(t, elapsedTime < 10*time.Second)
|
||||
|
||||
stats := pinger.Statistics()
|
||||
AssertTrue(t, stats != nil)
|
||||
if stats == nil {
|
||||
t.FailNow()
|
||||
}
|
||||
AssertTrue(t, stats.PacketsSent > 0)
|
||||
AssertTrue(t, stats.PacketsRecv > 0)
|
||||
}
|
||||
|
||||
func TestRunWithBackgroundContext(t *testing.T) {
|
||||
pinger := New("127.0.0.1")
|
||||
pinger.Count = 10
|
||||
pinger.Interval = 100 * time.Millisecond
|
||||
|
||||
err := pinger.Resolve()
|
||||
AssertNoError(t, err)
|
||||
|
||||
conn := new(testPacketConnOK)
|
||||
|
||||
err = pinger.run(context.Background(), conn)
|
||||
AssertTrue(t, err == nil)
|
||||
|
||||
stats := pinger.Statistics()
|
||||
AssertTrue(t, stats != nil)
|
||||
if stats == nil {
|
||||
t.FailNow()
|
||||
}
|
||||
AssertTrue(t, stats.PacketsRecv == 10)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user