diff --git a/tools/leaderelection/example/main.go b/tools/leaderelection/example/main.go index 91511e5b..ebcc0e8d 100644 --- a/tools/leaderelection/example/main.go +++ b/tools/leaderelection/example/main.go @@ -19,9 +19,11 @@ package main import ( "context" "flag" + "fmt" "log" "os" "os/signal" + "strings" "syscall" "time" @@ -31,6 +33,7 @@ import ( "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/leaderelection" "k8s.io/client-go/tools/leaderelection/resourcelock" + "k8s.io/client-go/transport" "k8s.io/klog" ) @@ -78,6 +81,10 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // use a client that will stop allowing new requests once the context ends + config.Wrap(transport.ContextCanceller(ctx, fmt.Errorf("the leader is shutting down"))) + exampleClient := kubernetes.NewForConfigOrDie(config).CoreV1() + // listen for interrupts or the Linux SIGTERM signal and cancel // our context, which the leader election code will observe and // step down @@ -116,6 +123,12 @@ func main() { }, }) + // because the context is closed, the client should report errors + _, err = exampleClient.ConfigMaps(args[1]).Get(args[2], metav1.GetOptions{}) + if err == nil || !strings.Contains(err.Error(), "the leader is shutting down") { + log.Fatalf("%s: expected to get an error when trying to make a client call: %v", id, err) + } + // we no longer hold the lease, so perform any cleanup and then // exit log.Printf("%s: done", id) diff --git a/transport/transport.go b/transport/transport.go index f62f8003..2a145c97 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -17,6 +17,7 @@ limitations under the License. package transport import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -196,3 +197,31 @@ func Wrappers(fns ...WrapperFunc) WrapperFunc { return base } } + +// ContextCanceller prevents new requests after the provided context is finished. +// err is returned when the context is closed, allowing the caller to provide a context +// appropriate error. +func ContextCanceller(ctx context.Context, err error) WrapperFunc { + return func(rt http.RoundTripper) http.RoundTripper { + return &contextCanceller{ + ctx: ctx, + rt: rt, + err: err, + } + } +} + +type contextCanceller struct { + ctx context.Context + rt http.RoundTripper + err error +} + +func (b *contextCanceller) RoundTrip(req *http.Request) (*http.Response, error) { + select { + case <-b.ctx.Done(): + return nil, b.err + default: + return b.rt.RoundTrip(req) + } +} diff --git a/transport/transport_test.go b/transport/transport_test.go index 66850121..d8e75443 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -17,8 +17,10 @@ limitations under the License. package transport import ( + "context" "crypto/tls" "errors" + "fmt" "net/http" "testing" ) @@ -397,3 +399,54 @@ func TestWrappers(t *testing.T) { }) } } + +func Test_contextCanceller_RoundTrip(t *testing.T) { + tests := []struct { + name string + open bool + want bool + }{ + {name: "open context should call nested round tripper", open: true, want: true}, + {name: "closed context should return a known error", open: false, want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &http.Request{} + rt := &fakeRoundTripper{Resp: &http.Response{}} + ctx := context.Background() + if !tt.open { + c, fn := context.WithCancel(ctx) + fn() + ctx = c + } + errTesting := fmt.Errorf("testing") + b := &contextCanceller{ + rt: rt, + ctx: ctx, + err: errTesting, + } + got, err := b.RoundTrip(req) + if tt.want { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if got != rt.Resp { + t.Errorf("wanted response") + } + if req != rt.Req { + t.Errorf("expect nested call") + } + } else { + if err != errTesting { + t.Errorf("unexpected error: %v", err) + } + if got != nil { + t.Errorf("wanted no response") + } + if rt.Req != nil { + t.Errorf("want no nested call") + } + } + }) + } +}