diff --git a/ping.go b/ping.go index 1419dc6..bd4301e 100644 --- a/ping.go +++ b/ping.go @@ -53,6 +53,7 @@ package ping import ( "bytes" + "context" "errors" "fmt" "log" @@ -420,6 +421,13 @@ func (p *Pinger) Mark() uint { // done. If Count or Interval are not specified, it will run continuously until // it is interrupted. func (p *Pinger) Run() error { + return p.RunWithContext(context.Background()) +} + +// RunWithContext runs the pinger with a context. This is a blocking function that will exit when it's +// done or if the context is canceled. If Count or Interval are not specified, it will run continuously until +// it is interrupted. +func (p *Pinger) RunWithContext(ctx context.Context) error { var conn packetConn var err error if p.Size < timeSliceLength+trackerLength { @@ -444,10 +452,10 @@ func (p *Pinger) Run() error { conn.SetTTL(p.TTL) conn.SetTOS(p.TOS) - return p.run(conn) + return p.run(ctx, conn) } -func (p *Pinger) run(conn packetConn) error { +func (p *Pinger) run(ctx context.Context, conn packetConn) error { if err := conn.SetFlagTTL(); err != nil { return err } @@ -460,7 +468,16 @@ func (p *Pinger) run(conn packetConn) error { handler() } - var g errgroup.Group + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + select { + case <-ctx.Done(): + p.Stop() + case <-p.done: + } + return nil + }) g.Go(func() error { defer p.Stop() diff --git a/ping_test.go b/ping_test.go index fe726ee..32a9d40 100644 --- a/ping_test.go +++ b/ping_test.go @@ -2,6 +2,7 @@ package ping import ( "bytes" + "context" "errors" "net" "runtime/debug" @@ -686,7 +687,7 @@ func TestRunBadWrite(t *testing.T) { var conn testPacketConnBadWrite - err = pinger.run(conn) + err = pinger.run(context.Background(), conn) AssertTrue(t, err != nil) stats := pinger.Statistics() @@ -715,7 +716,7 @@ func TestRunBadRead(t *testing.T) { var conn testPacketConnBadRead - err = pinger.run(conn) + err = pinger.run(context.Background(), conn) AssertTrue(t, err != nil) stats := pinger.Statistics() @@ -773,7 +774,7 @@ func TestRunOK(t *testing.T) { conn := new(testPacketConnOK) - err = pinger.run(conn) + err = pinger.run(context.Background(), conn) AssertTrue(t, err == nil) stats := pinger.Statistics() @@ -786,3 +787,49 @@ func TestRunOK(t *testing.T) { AssertTrue(t, stats.MinRtt >= 10*time.Millisecond) AssertTrue(t, stats.MinRtt <= 12*time.Millisecond) } + +func TestRunWithTimeoutContext(t *testing.T) { + pinger := New("127.0.0.1") + + err := pinger.Resolve() + AssertNoError(t, err) + + conn := new(testPacketConnOK) + + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + err = pinger.run(ctx, conn) + AssertTrue(t, err == nil) + elapsedTime := time.Since(start) + AssertTrue(t, elapsedTime < 10*time.Second) + + stats := pinger.Statistics() + AssertTrue(t, stats != nil) + if stats == nil { + t.FailNow() + } + AssertTrue(t, stats.PacketsSent > 0) + AssertTrue(t, stats.PacketsRecv > 0) +} + +func TestRunWithBackgroundContext(t *testing.T) { + pinger := New("127.0.0.1") + pinger.Count = 10 + pinger.Interval = 100 * time.Millisecond + + err := pinger.Resolve() + AssertNoError(t, err) + + conn := new(testPacketConnOK) + + err = pinger.run(context.Background(), conn) + AssertTrue(t, err == nil) + + stats := pinger.Statistics() + AssertTrue(t, stats != nil) + if stats == nil { + t.FailNow() + } + AssertTrue(t, stats.PacketsRecv == 10) +}