Create a Run API with a context object

Signed-off-by: TheRushingWookie <3181551+TheRushingWookie@users.noreply.github.com>
This commit is contained in:
TheRushingWookie 2023-04-03 17:40:17 -05:00 committed by Kaj Niemi
parent a4fbe74944
commit 5e08633e1b
2 changed files with 70 additions and 6 deletions

23
ping.go
View File

@ -53,6 +53,7 @@ package ping
import (
"bytes"
"context"
"errors"
"fmt"
"log"
@ -420,6 +421,13 @@ func (p *Pinger) Mark() uint {
// done. If Count or Interval are not specified, it will run continuously until
// it is interrupted.
func (p *Pinger) Run() error {
return p.RunWithContext(context.Background())
}
// RunWithContext runs the pinger with a context. This is a blocking function that will exit when it's
// done or if the context is canceled. If Count or Interval are not specified, it will run continuously until
// it is interrupted.
func (p *Pinger) RunWithContext(ctx context.Context) error {
var conn packetConn
var err error
if p.Size < timeSliceLength+trackerLength {
@ -444,10 +452,10 @@ func (p *Pinger) Run() error {
conn.SetTTL(p.TTL)
conn.SetTOS(p.TOS)
return p.run(conn)
return p.run(ctx, conn)
}
func (p *Pinger) run(conn packetConn) error {
func (p *Pinger) run(ctx context.Context, conn packetConn) error {
if err := conn.SetFlagTTL(); err != nil {
return err
}
@ -460,7 +468,16 @@ func (p *Pinger) run(conn packetConn) error {
handler()
}
var g errgroup.Group
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
select {
case <-ctx.Done():
p.Stop()
case <-p.done:
}
return nil
})
g.Go(func() error {
defer p.Stop()

View File

@ -2,6 +2,7 @@ package ping
import (
"bytes"
"context"
"errors"
"net"
"runtime/debug"
@ -686,7 +687,7 @@ func TestRunBadWrite(t *testing.T) {
var conn testPacketConnBadWrite
err = pinger.run(conn)
err = pinger.run(context.Background(), conn)
AssertTrue(t, err != nil)
stats := pinger.Statistics()
@ -715,7 +716,7 @@ func TestRunBadRead(t *testing.T) {
var conn testPacketConnBadRead
err = pinger.run(conn)
err = pinger.run(context.Background(), conn)
AssertTrue(t, err != nil)
stats := pinger.Statistics()
@ -773,7 +774,7 @@ func TestRunOK(t *testing.T) {
conn := new(testPacketConnOK)
err = pinger.run(conn)
err = pinger.run(context.Background(), conn)
AssertTrue(t, err == nil)
stats := pinger.Statistics()
@ -786,3 +787,49 @@ func TestRunOK(t *testing.T) {
AssertTrue(t, stats.MinRtt >= 10*time.Millisecond)
AssertTrue(t, stats.MinRtt <= 12*time.Millisecond)
}
func TestRunWithTimeoutContext(t *testing.T) {
pinger := New("127.0.0.1")
err := pinger.Resolve()
AssertNoError(t, err)
conn := new(testPacketConnOK)
start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err = pinger.run(ctx, conn)
AssertTrue(t, err == nil)
elapsedTime := time.Since(start)
AssertTrue(t, elapsedTime < 10*time.Second)
stats := pinger.Statistics()
AssertTrue(t, stats != nil)
if stats == nil {
t.FailNow()
}
AssertTrue(t, stats.PacketsSent > 0)
AssertTrue(t, stats.PacketsRecv > 0)
}
func TestRunWithBackgroundContext(t *testing.T) {
pinger := New("127.0.0.1")
pinger.Count = 10
pinger.Interval = 100 * time.Millisecond
err := pinger.Resolve()
AssertNoError(t, err)
conn := new(testPacketConnOK)
err = pinger.run(context.Background(), conn)
AssertTrue(t, err == nil)
stats := pinger.Statistics()
AssertTrue(t, stats != nil)
if stats == nil {
t.FailNow()
}
AssertTrue(t, stats.PacketsRecv == 10)
}