From fe74efb1f90826b1903d2908ff9e528329bebea0 Mon Sep 17 00:00:00 2001 From: Clayton Coleman Date: Thu, 27 Dec 2018 12:29:34 -0500 Subject: [PATCH] Add transport wrapper that blocks api calls after context close The ContextCanceller transport wrapper blocks all API requests after the provided context is closed. Used with the leader election step down, a controller can ensure that new requests are not made after the client has stepped down. --- staging/src/k8s.io/client-go/rest/BUILD | 1 + .../tools/leaderelection/example/BUILD | 1 + .../tools/leaderelection/example/main.go | 13 +++++ .../k8s.io/client-go/transport/transport.go | 29 ++++++++++ .../client-go/transport/transport_test.go | 53 +++++++++++++++++++ 5 files changed, 97 insertions(+) diff --git a/staging/src/k8s.io/client-go/rest/BUILD b/staging/src/k8s.io/client-go/rest/BUILD index 9f00aac950e..4966e09a2b0 100644 --- a/staging/src/k8s.io/client-go/rest/BUILD +++ b/staging/src/k8s.io/client-go/rest/BUILD @@ -36,6 +36,7 @@ go_test( "//staging/src/k8s.io/client-go/kubernetes/scheme:go_default_library", "//staging/src/k8s.io/client-go/rest/watch:go_default_library", "//staging/src/k8s.io/client-go/tools/clientcmd/api:go_default_library", + "//staging/src/k8s.io/client-go/transport:go_default_library", "//staging/src/k8s.io/client-go/util/flowcontrol:go_default_library", "//staging/src/k8s.io/client-go/util/testing:go_default_library", "//vendor/github.com/google/gofuzz:go_default_library", diff --git a/staging/src/k8s.io/client-go/tools/leaderelection/example/BUILD b/staging/src/k8s.io/client-go/tools/leaderelection/example/BUILD index cb1cfbca0fc..13ff7abd262 100644 --- a/staging/src/k8s.io/client-go/tools/leaderelection/example/BUILD +++ b/staging/src/k8s.io/client-go/tools/leaderelection/example/BUILD @@ -13,6 +13,7 @@ go_library( "//staging/src/k8s.io/client-go/tools/clientcmd:go_default_library", "//staging/src/k8s.io/client-go/tools/leaderelection:go_default_library", "//staging/src/k8s.io/client-go/tools/leaderelection/resourcelock:go_default_library", + "//staging/src/k8s.io/client-go/transport:go_default_library", "//vendor/k8s.io/klog:go_default_library", ], ) diff --git a/staging/src/k8s.io/client-go/tools/leaderelection/example/main.go b/staging/src/k8s.io/client-go/tools/leaderelection/example/main.go index 91511e5b16d..ebcc0e8dad5 100644 --- a/staging/src/k8s.io/client-go/tools/leaderelection/example/main.go +++ b/staging/src/k8s.io/client-go/tools/leaderelection/example/main.go @@ -19,9 +19,11 @@ package main import ( "context" "flag" + "fmt" "log" "os" "os/signal" + "strings" "syscall" "time" @@ -31,6 +33,7 @@ import ( "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/leaderelection" "k8s.io/client-go/tools/leaderelection/resourcelock" + "k8s.io/client-go/transport" "k8s.io/klog" ) @@ -78,6 +81,10 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // use a client that will stop allowing new requests once the context ends + config.Wrap(transport.ContextCanceller(ctx, fmt.Errorf("the leader is shutting down"))) + exampleClient := kubernetes.NewForConfigOrDie(config).CoreV1() + // listen for interrupts or the Linux SIGTERM signal and cancel // our context, which the leader election code will observe and // step down @@ -116,6 +123,12 @@ func main() { }, }) + // because the context is closed, the client should report errors + _, err = exampleClient.ConfigMaps(args[1]).Get(args[2], metav1.GetOptions{}) + if err == nil || !strings.Contains(err.Error(), "the leader is shutting down") { + log.Fatalf("%s: expected to get an error when trying to make a client call: %v", id, err) + } + // we no longer hold the lease, so perform any cleanup and then // exit log.Printf("%s: done", id) diff --git a/staging/src/k8s.io/client-go/transport/transport.go b/staging/src/k8s.io/client-go/transport/transport.go index f62f8003d6a..2a145c971a3 100644 --- a/staging/src/k8s.io/client-go/transport/transport.go +++ b/staging/src/k8s.io/client-go/transport/transport.go @@ -17,6 +17,7 @@ limitations under the License. package transport import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -196,3 +197,31 @@ func Wrappers(fns ...WrapperFunc) WrapperFunc { return base } } + +// ContextCanceller prevents new requests after the provided context is finished. +// err is returned when the context is closed, allowing the caller to provide a context +// appropriate error. +func ContextCanceller(ctx context.Context, err error) WrapperFunc { + return func(rt http.RoundTripper) http.RoundTripper { + return &contextCanceller{ + ctx: ctx, + rt: rt, + err: err, + } + } +} + +type contextCanceller struct { + ctx context.Context + rt http.RoundTripper + err error +} + +func (b *contextCanceller) RoundTrip(req *http.Request) (*http.Response, error) { + select { + case <-b.ctx.Done(): + return nil, b.err + default: + return b.rt.RoundTrip(req) + } +} diff --git a/staging/src/k8s.io/client-go/transport/transport_test.go b/staging/src/k8s.io/client-go/transport/transport_test.go index 6685012106d..d8e75443210 100644 --- a/staging/src/k8s.io/client-go/transport/transport_test.go +++ b/staging/src/k8s.io/client-go/transport/transport_test.go @@ -17,8 +17,10 @@ limitations under the License. package transport import ( + "context" "crypto/tls" "errors" + "fmt" "net/http" "testing" ) @@ -397,3 +399,54 @@ func TestWrappers(t *testing.T) { }) } } + +func Test_contextCanceller_RoundTrip(t *testing.T) { + tests := []struct { + name string + open bool + want bool + }{ + {name: "open context should call nested round tripper", open: true, want: true}, + {name: "closed context should return a known error", open: false, want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &http.Request{} + rt := &fakeRoundTripper{Resp: &http.Response{}} + ctx := context.Background() + if !tt.open { + c, fn := context.WithCancel(ctx) + fn() + ctx = c + } + errTesting := fmt.Errorf("testing") + b := &contextCanceller{ + rt: rt, + ctx: ctx, + err: errTesting, + } + got, err := b.RoundTrip(req) + if tt.want { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if got != rt.Resp { + t.Errorf("wanted response") + } + if req != rt.Req { + t.Errorf("expect nested call") + } + } else { + if err != errTesting { + t.Errorf("unexpected error: %v", err) + } + if got != nil { + t.Errorf("wanted no response") + } + if rt.Req != nil { + t.Errorf("want no nested call") + } + } + }) + } +}