diff --git a/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/mutating/dispatcher.go b/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/mutating/dispatcher.go index cf0e76a60c5..8c75fdcf2f7 100644 --- a/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/mutating/dispatcher.go +++ b/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/mutating/dispatcher.go @@ -136,7 +136,15 @@ func (a *mutatingDispatcher) Dispatch(ctx context.Context, attr admission.Attrib if ignoreClientCallFailures { klog.Warningf("Failed calling webhook, failing open %v: %v", hook.Name, callErr) utilruntime.HandleError(callErr) - continue + + select { + case <-ctx.Done(): + // parent context is canceled or timed out, no point in continuing + return apierrors.NewTimeoutError("request did not complete within requested timeout", 0) + default: + // individual webhook timed out, but parent context did not, continue + continue + } } klog.Warningf("Failed calling webhook, failing closed %v: %v", hook.Name, err) return apierrors.NewInternalError(err) @@ -181,10 +189,29 @@ func (a *mutatingDispatcher) callAttrMutatingHook(ctx context.Context, h *v1beta utiltrace.Field{"operation", attr.GetOperation()}, utiltrace.Field{"UID", uid}) defer trace.LogIfLong(500 * time.Millisecond) - r := client.Post().Context(ctx).Body(request) + + // if the webhook has a specific timeout, wrap the context to apply it if h.TimeoutSeconds != nil { - r = r.Timeout(time.Duration(*h.TimeoutSeconds) * time.Second) + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(*h.TimeoutSeconds)*time.Second) + defer cancel() } + + r := client.Post().Context(ctx).Body(request) + + // if the context has a deadline, set it as a parameter to inform the backend + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + // compute the timeout + if timeout := time.Until(deadline); timeout > 0 { + // if it's not an even number of seconds, round up to the nearest second + if truncated := timeout.Truncate(time.Second); truncated != timeout { + timeout = truncated + time.Second + } + // set the timeout + r.Timeout(timeout) + } + } + if err := r.Do().Into(response); err != nil { return false, &webhookutil.ErrCallingWebhook{WebhookName: h.Name, Reason: err} } diff --git a/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/validating/dispatcher.go b/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/validating/dispatcher.go index 6525524f5b0..3e1cb589aaf 100644 --- a/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/validating/dispatcher.go +++ b/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/validating/dispatcher.go @@ -80,6 +80,14 @@ func (d *validatingDispatcher) Dispatch(ctx context.Context, attr admission.Attr return nil } + // Check if the request has already timed out before spawning remote calls + select { + case <-ctx.Done(): + // parent context is canceled or timed out, no point in continuing + return apierrors.NewTimeoutError("request did not complete within requested timeout", 0) + default: + } + wg := sync.WaitGroup{} errCh := make(chan error, len(relevantHooks)) wg.Add(len(relevantHooks)) @@ -162,10 +170,29 @@ func (d *validatingDispatcher) callHook(ctx context.Context, h *v1beta1.Validati utiltrace.Field{"operation", attr.GetOperation()}, utiltrace.Field{"UID", uid}) defer trace.LogIfLong(500 * time.Millisecond) - r := client.Post().Context(ctx).Body(request) + + // if the webhook has a specific timeout, wrap the context to apply it if h.TimeoutSeconds != nil { - r = r.Timeout(time.Duration(*h.TimeoutSeconds) * time.Second) + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(*h.TimeoutSeconds)*time.Second) + defer cancel() } + + r := client.Post().Context(ctx).Body(request) + + // if the context has a deadline, set it as a parameter to inform the backend + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + // compute the timeout + if timeout := time.Until(deadline); timeout > 0 { + // if it's not an even number of seconds, round up to the nearest second + if truncated := timeout.Truncate(time.Second); truncated != timeout { + timeout = truncated + time.Second + } + // set the timeout + r.Timeout(timeout) + } + } + if err := r.Do().Into(response); err != nil { return &webhookutil.ErrCallingWebhook{WebhookName: h.Name, Reason: err} } diff --git a/test/integration/apiserver/admissionwebhook/BUILD b/test/integration/apiserver/admissionwebhook/BUILD index 2ea7dd8ed55..c7ce021645d 100644 --- a/test/integration/apiserver/admissionwebhook/BUILD +++ b/test/integration/apiserver/admissionwebhook/BUILD @@ -7,6 +7,7 @@ go_test( "broken_webhook_test.go", "main_test.go", "reinvocation_test.go", + "timeout_test.go", ], rundir = ".", tags = [ diff --git a/test/integration/apiserver/admissionwebhook/timeout_test.go b/test/integration/apiserver/admissionwebhook/timeout_test.go new file mode 100644 index 00000000000..e37cd2becf7 --- /dev/null +++ b/test/integration/apiserver/admissionwebhook/timeout_test.go @@ -0,0 +1,458 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package admissionwebhook + +import ( + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "sort" + "strings" + "sync" + "testing" + "time" + + "k8s.io/api/admission/v1beta1" + admissionv1beta1 "k8s.io/api/admissionregistration/v1beta1" + corev1 "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/wait" + clientset "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + kubeapiservertesting "k8s.io/kubernetes/cmd/kube-apiserver/app/testing" + "k8s.io/kubernetes/test/integration/framework" +) + +const ( + testTimeoutClientUsername = "webhook-timeout-integration-client" +) + +// TestWebhookTimeoutWithWatchCache ensures that the admission webhook timeout policy is applied correctly with the watch cache enabled. +func TestWebhookTimeoutWithWatchCache(t *testing.T) { + testWebhookTimeout(t, true) +} + +// TestWebhookTimeoutWithoutWatchCache ensures that the admission webhook timeout policy is applied correctly without the watch cache enabled. +func TestWebhookTimeoutWithoutWatchCache(t *testing.T) { + testWebhookTimeout(t, false) +} + +type invocation struct { + path string + timeoutSeconds int +} + +// testWebhookTimeout ensures that the admission webhook timeout policy is applied correctly. +func testWebhookTimeout(t *testing.T, watchCache bool) { + type testWebhook struct { + path string + timeoutSeconds int32 + policy admissionv1beta1.FailurePolicyType + objectSelector *metav1.LabelSelector + } + + testCases := []struct { + name string + timeoutSeconds int32 + mutatingWebhooks []testWebhook + validatingWebhooks []testWebhook + expectInvocations []invocation + expectError bool + errorContains string + }{ + { + name: "minimum of request timeout or webhook timeout propagated", + timeoutSeconds: 10, + mutatingWebhooks: []testWebhook{ + {path: "/mutating/1/0s", policy: admissionv1beta1.Fail, timeoutSeconds: 20}, + {path: "/mutating/2/0s", policy: admissionv1beta1.Fail, timeoutSeconds: 5}, + }, + validatingWebhooks: []testWebhook{ + {path: "/validating/3/0s", policy: admissionv1beta1.Fail, timeoutSeconds: 20}, + {path: "/validating/4/0s", policy: admissionv1beta1.Fail, timeoutSeconds: 5}, + }, + expectInvocations: []invocation{ + {path: "/mutating/1/0s", timeoutSeconds: 10}, // from request + {path: "/mutating/2/0s", timeoutSeconds: 5}, // from webhook config + {path: "/validating/3/0s", timeoutSeconds: 10}, // from request + {path: "/validating/4/0s", timeoutSeconds: 5}, // from webhook config + }, + }, + { + name: "webhooks consume client timeout available, not webhook timeout", + timeoutSeconds: 10, + mutatingWebhooks: []testWebhook{ + {path: "/mutating/1/1s", policy: admissionv1beta1.Fail, timeoutSeconds: 20}, + {path: "/mutating/2/1s", policy: admissionv1beta1.Fail, timeoutSeconds: 5}, + {path: "/mutating/3/1s", policy: admissionv1beta1.Fail, timeoutSeconds: 20}, + }, + validatingWebhooks: []testWebhook{ + {path: "/validating/4/1s", policy: admissionv1beta1.Fail, timeoutSeconds: 5}, + {path: "/validating/5/1s", policy: admissionv1beta1.Fail, timeoutSeconds: 10}, + {path: "/validating/6/1s", policy: admissionv1beta1.Fail, timeoutSeconds: 20}, + }, + expectInvocations: []invocation{ + {path: "/mutating/1/1s", timeoutSeconds: 10}, // from request + {path: "/mutating/2/1s", timeoutSeconds: 5}, // from webhook config (less than request - 1s consumed) + {path: "/mutating/3/1s", timeoutSeconds: 8}, // from request - 2s consumed + {path: "/validating/4/1s", timeoutSeconds: 5}, // from webhook config (less than request - 3s consumed by mutating) + {path: "/validating/5/1s", timeoutSeconds: 7}, // from request - 3s consumed by mutating + {path: "/validating/6/1s", timeoutSeconds: 7}, // from request - 3s consumed by mutating + }, + }, + { + name: "timed out client requests skip later mutating webhooks (regardless of failure policy) and fail", + timeoutSeconds: 3, + mutatingWebhooks: []testWebhook{ + {path: "/mutating/1/5s", policy: admissionv1beta1.Ignore, timeoutSeconds: 4}, + {path: "/mutating/2/1s", policy: admissionv1beta1.Ignore, timeoutSeconds: 5}, + {path: "/mutating/3/1s", policy: admissionv1beta1.Ignore, timeoutSeconds: 5}, + }, + expectInvocations: []invocation{ + {path: "/mutating/1/5s", timeoutSeconds: 3}, // from request + }, + expectError: true, + errorContains: "request did not complete within requested timeout", + }, + } + + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(localhostCert) { + t.Fatal("Failed to append Cert from PEM") + } + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Fatalf("Failed to build cert with error: %+v", err) + } + + recorder := &timeoutRecorder{invocations: []invocation{}} + webhookServer := httptest.NewUnstartedServer(newTimeoutWebhookHandler(recorder)) + webhookServer.TLS = &tls.Config{ + + RootCAs: roots, + Certificates: []tls.Certificate{cert}, + } + webhookServer.StartTLS() + defer webhookServer.Close() + + s := kubeapiservertesting.StartTestServerOrDie(t, kubeapiservertesting.NewDefaultTestServerOptions(), []string{ + "--disable-admission-plugins=ServiceAccount", + fmt.Sprintf("--watch-cache=%v", watchCache), + }, framework.SharedEtcd()) + defer s.TearDownFn() + + // Configure a client with a distinct user name so that it is easy to distinguish requests + // made by the client from requests made by controllers. We use this to filter out requests + // before recording them to ensure we don't accidentally mistake requests from controllers + // as requests made by the client. + clientConfig := rest.CopyConfig(s.ClientConfig) + clientConfig.Timeout = 0 // no timeout, we want to set this manually + clientConfig.Impersonate.UserName = testTimeoutClientUsername + clientConfig.Impersonate.Groups = []string{"system:masters", "system:authenticated"} + client, err := clientset.NewForConfig(clientConfig) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + _, err = client.CoreV1().Pods("default").Create(timeoutMarkerFixture) + if err != nil { + t.Fatal(err) + } + + for i, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + upCh := recorder.Reset() + ns := fmt.Sprintf("reinvoke-%d", i) + _, err = client.CoreV1().Namespaces().Create(&v1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: ns}}) + if err != nil { + t.Fatal(err) + } + + mutatingWebhooks := []admissionv1beta1.MutatingWebhook{} + for j, webhook := range tt.mutatingWebhooks { + name := fmt.Sprintf("admission.integration.test.%d.%s", j, strings.Replace(strings.TrimPrefix(webhook.path, "/"), "/", "-", -1)) + endpoint := webhookServer.URL + webhook.path + mutatingWebhooks = append(mutatingWebhooks, admissionv1beta1.MutatingWebhook{ + Name: name, + ClientConfig: admissionv1beta1.WebhookClientConfig{ + URL: &endpoint, + CABundle: localhostCert, + }, + Rules: []admissionv1beta1.RuleWithOperations{{ + Operations: []admissionv1beta1.OperationType{admissionv1beta1.OperationAll}, + Rule: admissionv1beta1.Rule{APIGroups: []string{""}, APIVersions: []string{"v1"}, Resources: []string{"pods"}}, + }}, + ObjectSelector: webhook.objectSelector, + FailurePolicy: &tt.mutatingWebhooks[j].policy, + TimeoutSeconds: &tt.mutatingWebhooks[j].timeoutSeconds, + AdmissionReviewVersions: []string{"v1beta1"}, + }) + } + mutatingCfg, err := client.AdmissionregistrationV1beta1().MutatingWebhookConfigurations().Create(&admissionv1beta1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("admission.integration.test-%d", i)}, + Webhooks: mutatingWebhooks, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + err := client.AdmissionregistrationV1beta1().MutatingWebhookConfigurations().Delete(mutatingCfg.GetName(), &metav1.DeleteOptions{}) + if err != nil { + t.Fatal(err) + } + }() + + validatingWebhooks := []admissionv1beta1.ValidatingWebhook{} + for j, webhook := range tt.validatingWebhooks { + name := fmt.Sprintf("admission.integration.test.%d.%s", j, strings.Replace(strings.TrimPrefix(webhook.path, "/"), "/", "-", -1)) + endpoint := webhookServer.URL + webhook.path + validatingWebhooks = append(validatingWebhooks, admissionv1beta1.ValidatingWebhook{ + Name: name, + ClientConfig: admissionv1beta1.WebhookClientConfig{ + URL: &endpoint, + CABundle: localhostCert, + }, + Rules: []admissionv1beta1.RuleWithOperations{{ + Operations: []admissionv1beta1.OperationType{admissionv1beta1.OperationAll}, + Rule: admissionv1beta1.Rule{APIGroups: []string{""}, APIVersions: []string{"v1"}, Resources: []string{"pods"}}, + }}, + ObjectSelector: webhook.objectSelector, + FailurePolicy: &tt.validatingWebhooks[j].policy, + TimeoutSeconds: &tt.validatingWebhooks[j].timeoutSeconds, + AdmissionReviewVersions: []string{"v1beta1"}, + }) + } + validatingCfg, err := client.AdmissionregistrationV1beta1().ValidatingWebhookConfigurations().Create(&admissionv1beta1.ValidatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{Name: fmt.Sprintf("admission.integration.test-%d", i)}, + Webhooks: validatingWebhooks, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + err := client.AdmissionregistrationV1beta1().ValidatingWebhookConfigurations().Delete(validatingCfg.GetName(), &metav1.DeleteOptions{}) + if err != nil { + t.Fatal(err) + } + }() + + // wait until new webhook is called the first time + if err := wait.PollImmediate(time.Millisecond*5, wait.ForeverTestTimeout, func() (bool, error) { + _, err = client.CoreV1().Pods("default").Patch(timeoutMarkerFixture.Name, types.JSONPatchType, []byte("[]")) + select { + case <-upCh: + return true, nil + default: + t.Logf("Waiting for webhook to become effective, getting marker object: %v", err) + return false, nil + } + }); err != nil { + t.Fatal(err) + } + + pod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{APIVersion: "v1", Kind: "Pod"}, + ObjectMeta: metav1.ObjectMeta{ + Namespace: ns, + Name: "labeled", + Labels: map[string]string{"x": "true"}, + }, + Spec: corev1.PodSpec{ + Containers: []v1.Container{{ + Name: "fake-name", + Image: "fakeimage", + }}, + }, + } + + body, err := json.Marshal(pod) + if err != nil { + t.Fatal(err) + } + + // set the timeout parameter manually so we don't actually cut off the request client-side, and wait for the server response + err = client.CoreV1().RESTClient().Post().Resource("pods").Namespace(ns).Body(body).Param("timeout", fmt.Sprintf("%ds", tt.timeoutSeconds)).Do().Error() + // _, err = testClient.CoreV1().Pods(ns).Create(pod) + + if tt.expectError { + if err == nil { + t.Fatalf("expected error but got none") + } + if tt.errorContains != "" { + if !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("expected an error saying %q, but got: %v", tt.errorContains, err) + } + } + return + } + + if err != nil { + t.Fatal(err) + } + + if tt.expectInvocations != nil { + for i, invocation := range tt.expectInvocations { + if len(recorder.invocations) < i { + t.Errorf("expected invocation of %s, got none", invocation.path) + continue + } + + if recorder.invocations[i].path != invocation.path { + t.Errorf("expected invocation of %s, got %s", invocation.path, recorder.invocations[i].path) + continue + } + if recorder.invocations[i].timeoutSeconds != invocation.timeoutSeconds { + t.Errorf("expected invocation of %s with timeout %d, got %d", invocation.path, invocation.timeoutSeconds, recorder.invocations[i].timeoutSeconds) + continue + } + } + + for _, invocation := range recorder.invocations[len(tt.expectInvocations):] { + t.Errorf("unexpected invocation of %s", invocation.path) + } + } + }) + } +} + +type timeoutRecorder struct { + mu sync.Mutex + upCh chan struct{} + upOnce sync.Once + invocations []invocation +} + +// Reset zeros out all counts and returns a channel that is closed when the first admission of the +// marker object is received. +func (i *timeoutRecorder) Reset() chan struct{} { + i.mu.Lock() + defer i.mu.Unlock() + i.invocations = []invocation{} + i.upCh = make(chan struct{}) + i.upOnce = sync.Once{} + return i.upCh +} + +func (i *timeoutRecorder) MarkerReceived() { + i.mu.Lock() + defer i.mu.Unlock() + i.upOnce.Do(func() { + close(i.upCh) + }) +} + +func (i *timeoutRecorder) RecordInvocation(call invocation) { + i.mu.Lock() + defer i.mu.Unlock() + i.invocations = append(i.invocations, call) + sort.SliceStable(i.invocations, func(a, b int) bool { + aValidating := strings.Contains(i.invocations[a].path, "validating") + bValidating := strings.Contains(i.invocations[b].path, "validating") + switch { + case aValidating && bValidating: + // sort validating by path + return strings.Compare(i.invocations[a].path, i.invocations[b].path) < 0 + case !aValidating && !bValidating: + // keep mutating in original order + return a < b + case aValidating && !bValidating: + // put validating last + return false + default: + return true + } + }) +} + +func newTimeoutWebhookHandler(recorder *timeoutRecorder) http.Handler { + allow := func(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(&v1beta1.AdmissionReview{ + Response: &v1beta1.AdmissionResponse{ + Allowed: true, + }, + }) + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + data, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), 400) + } + review := v1beta1.AdmissionReview{} + if err := json.Unmarshal(data, &review); err != nil { + http.Error(w, err.Error(), 400) + } + if review.Request.UserInfo.Username != testTimeoutClientUsername { + // skip requests not originating from this integration test's client + allow(w) + return + } + + if len(review.Request.Object.Raw) == 0 { + http.Error(w, err.Error(), 400) + } + pod := &corev1.Pod{} + if err := json.Unmarshal(review.Request.Object.Raw, pod); err != nil { + http.Error(w, err.Error(), 400) + } + + // When resetting between tests, a marker object is patched until this webhook + // observes it, at which point it is considered ready. + if pod.Namespace == timeoutMarkerFixture.Namespace && pod.Name == timeoutMarkerFixture.Name { + recorder.MarkerReceived() + allow(w) + return + } + + timeout, err := time.ParseDuration(r.URL.Query().Get("timeout")) + invocation := invocation{path: r.URL.Path, timeoutSeconds: int(timeout.Round(time.Second) / time.Second)} + recorder.RecordInvocation(invocation) + + switch { + case strings.HasSuffix(r.URL.Path, "/0s"): + allow(w) + case strings.HasSuffix(r.URL.Path, "/1s"): + time.Sleep(time.Second) + allow(w) + case strings.HasSuffix(r.URL.Path, "/5s"): + time.Sleep(5 * time.Second) + allow(w) + default: + http.NotFound(w, r) + } + }) +} + +var timeoutMarkerFixture = &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "default", + Name: "marker", + }, + Spec: corev1.PodSpec{ + Containers: []v1.Container{{ + Name: "fake-name", + Image: "fakeimage", + }}, + }, +}