diff --git a/staging/src/k8s.io/client-go/rest/BUILD b/staging/src/k8s.io/client-go/rest/BUILD index 9f00aac950e..4966e09a2b0 100644 --- a/staging/src/k8s.io/client-go/rest/BUILD +++ b/staging/src/k8s.io/client-go/rest/BUILD @@ -36,6 +36,7 @@ go_test( "//staging/src/k8s.io/client-go/kubernetes/scheme:go_default_library", "//staging/src/k8s.io/client-go/rest/watch:go_default_library", "//staging/src/k8s.io/client-go/tools/clientcmd/api:go_default_library", + "//staging/src/k8s.io/client-go/transport:go_default_library", "//staging/src/k8s.io/client-go/util/flowcontrol:go_default_library", "//staging/src/k8s.io/client-go/util/testing:go_default_library", "//vendor/github.com/google/gofuzz:go_default_library", diff --git a/staging/src/k8s.io/client-go/tools/leaderelection/example/BUILD b/staging/src/k8s.io/client-go/tools/leaderelection/example/BUILD index cb1cfbca0fc..13ff7abd262 100644 --- a/staging/src/k8s.io/client-go/tools/leaderelection/example/BUILD +++ b/staging/src/k8s.io/client-go/tools/leaderelection/example/BUILD @@ -13,6 +13,7 @@ go_library( "//staging/src/k8s.io/client-go/tools/clientcmd:go_default_library", "//staging/src/k8s.io/client-go/tools/leaderelection:go_default_library", "//staging/src/k8s.io/client-go/tools/leaderelection/resourcelock:go_default_library", + "//staging/src/k8s.io/client-go/transport:go_default_library", "//vendor/k8s.io/klog:go_default_library", ], ) diff --git a/staging/src/k8s.io/client-go/tools/leaderelection/example/main.go b/staging/src/k8s.io/client-go/tools/leaderelection/example/main.go index 91511e5b16d..ebcc0e8dad5 100644 --- a/staging/src/k8s.io/client-go/tools/leaderelection/example/main.go +++ b/staging/src/k8s.io/client-go/tools/leaderelection/example/main.go @@ -19,9 +19,11 @@ package main import ( "context" "flag" + "fmt" "log" "os" "os/signal" + "strings" "syscall" "time" @@ -31,6 +33,7 @@ import ( "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/leaderelection" "k8s.io/client-go/tools/leaderelection/resourcelock" + "k8s.io/client-go/transport" "k8s.io/klog" ) @@ -78,6 +81,10 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // use a client that will stop allowing new requests once the context ends + config.Wrap(transport.ContextCanceller(ctx, fmt.Errorf("the leader is shutting down"))) + exampleClient := kubernetes.NewForConfigOrDie(config).CoreV1() + // listen for interrupts or the Linux SIGTERM signal and cancel // our context, which the leader election code will observe and // step down @@ -116,6 +123,12 @@ func main() { }, }) + // because the context is closed, the client should report errors + _, err = exampleClient.ConfigMaps(args[1]).Get(args[2], metav1.GetOptions{}) + if err == nil || !strings.Contains(err.Error(), "the leader is shutting down") { + log.Fatalf("%s: expected to get an error when trying to make a client call: %v", id, err) + } + // we no longer hold the lease, so perform any cleanup and then // exit log.Printf("%s: done", id) diff --git a/staging/src/k8s.io/client-go/transport/transport.go b/staging/src/k8s.io/client-go/transport/transport.go index f62f8003d6a..2a145c971a3 100644 --- a/staging/src/k8s.io/client-go/transport/transport.go +++ b/staging/src/k8s.io/client-go/transport/transport.go @@ -17,6 +17,7 @@ limitations under the License. package transport import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -196,3 +197,31 @@ func Wrappers(fns ...WrapperFunc) WrapperFunc { return base } } + +// ContextCanceller prevents new requests after the provided context is finished. +// err is returned when the context is closed, allowing the caller to provide a context +// appropriate error. +func ContextCanceller(ctx context.Context, err error) WrapperFunc { + return func(rt http.RoundTripper) http.RoundTripper { + return &contextCanceller{ + ctx: ctx, + rt: rt, + err: err, + } + } +} + +type contextCanceller struct { + ctx context.Context + rt http.RoundTripper + err error +} + +func (b *contextCanceller) RoundTrip(req *http.Request) (*http.Response, error) { + select { + case <-b.ctx.Done(): + return nil, b.err + default: + return b.rt.RoundTrip(req) + } +} diff --git a/staging/src/k8s.io/client-go/transport/transport_test.go b/staging/src/k8s.io/client-go/transport/transport_test.go index 6685012106d..d8e75443210 100644 --- a/staging/src/k8s.io/client-go/transport/transport_test.go +++ b/staging/src/k8s.io/client-go/transport/transport_test.go @@ -17,8 +17,10 @@ limitations under the License. package transport import ( + "context" "crypto/tls" "errors" + "fmt" "net/http" "testing" ) @@ -397,3 +399,54 @@ func TestWrappers(t *testing.T) { }) } } + +func Test_contextCanceller_RoundTrip(t *testing.T) { + tests := []struct { + name string + open bool + want bool + }{ + {name: "open context should call nested round tripper", open: true, want: true}, + {name: "closed context should return a known error", open: false, want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &http.Request{} + rt := &fakeRoundTripper{Resp: &http.Response{}} + ctx := context.Background() + if !tt.open { + c, fn := context.WithCancel(ctx) + fn() + ctx = c + } + errTesting := fmt.Errorf("testing") + b := &contextCanceller{ + rt: rt, + ctx: ctx, + err: errTesting, + } + got, err := b.RoundTrip(req) + if tt.want { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if got != rt.Resp { + t.Errorf("wanted response") + } + if req != rt.Req { + t.Errorf("expect nested call") + } + } else { + if err != errTesting { + t.Errorf("unexpected error: %v", err) + } + if got != nil { + t.Errorf("wanted no response") + } + if rt.Req != nil { + t.Errorf("want no nested call") + } + } + }) + } +}