diff --git a/staging/src/k8s.io/client-go/rest/client.go b/staging/src/k8s.io/client-go/rest/client.go index 159caa13fab..29a254484c4 100644 --- a/staging/src/k8s.io/client-go/rest/client.go +++ b/staging/src/k8s.io/client-go/rest/client.go @@ -101,7 +101,7 @@ type RESTClient struct { // warningHandler is shared among all requests created by this client. // If not set, defaultWarningHandler is used. - warningHandler WarningHandler + warningHandler WarningHandlerWithContext // Set specific behavior of the client. If not set http.DefaultClient will be used. Client *http.Client diff --git a/staging/src/k8s.io/client-go/rest/config.go b/staging/src/k8s.io/client-go/rest/config.go index f2e813d075e..fd4324efb6b 100644 --- a/staging/src/k8s.io/client-go/rest/config.go +++ b/staging/src/k8s.io/client-go/rest/config.go @@ -129,10 +129,23 @@ type Config struct { RateLimiter flowcontrol.RateLimiter // WarningHandler handles warnings in server responses. - // If not set, the default warning handler is used. - // See documentation for SetDefaultWarningHandler() for details. + // If this and WarningHandlerWithContext are not set, the + // default warning handler is used. If both are set, + // WarningHandlerWithContext is used. + // + // See documentation for [SetDefaultWarningHandler] for details. + // + //logcheck:context // WarningHandlerWithContext should be used instead of WarningHandler in code which supports contextual logging. WarningHandler WarningHandler + // WarningHandlerWithContext handles warnings in server responses. + // If this and WarningHandler are not set, the + // default warning handler is used. If both are set, + // WarningHandlerWithContext is used. + // + // See documentation for [SetDefaultWarningHandler] for details. + WarningHandlerWithContext WarningHandlerWithContext + // The maximum length of time to wait before giving up on a server request. A value of zero means no timeout. Timeout time.Duration @@ -381,12 +394,27 @@ func RESTClientForConfigAndClient(config *Config, httpClient *http.Client) (*RES } restClient, err := NewRESTClient(baseURL, versionedAPIPath, clientContent, rateLimiter, httpClient) - if err == nil && config.WarningHandler != nil { - restClient.warningHandler = config.WarningHandler - } + maybeSetWarningHandler(restClient, config.WarningHandler, config.WarningHandlerWithContext) return restClient, err } +// maybeSetWarningHandler sets the handlerWithContext if non-nil, +// otherwise the handler with a wrapper if non-nil, +// and does nothing if both are nil. +// +// May be called for a nil client. +func maybeSetWarningHandler(c *RESTClient, handler WarningHandler, handlerWithContext WarningHandlerWithContext) { + if c == nil { + return + } + switch { + case handlerWithContext != nil: + c.warningHandler = handlerWithContext + case handler != nil: + c.warningHandler = warningLoggerNopContext{l: handler} + } +} + // UnversionedRESTClientFor is the same as RESTClientFor, except that it allows // the config.Version to be empty. func UnversionedRESTClientFor(config *Config) (*RESTClient, error) { @@ -448,9 +476,7 @@ func UnversionedRESTClientForConfigAndClient(config *Config, httpClient *http.Cl } restClient, err := NewRESTClient(baseURL, versionedAPIPath, clientContent, rateLimiter, httpClient) - if err == nil && config.WarningHandler != nil { - restClient.warningHandler = config.WarningHandler - } + maybeSetWarningHandler(restClient, config.WarningHandler, config.WarningHandlerWithContext) return restClient, err } @@ -616,15 +642,16 @@ func AnonymousClientConfig(config *Config) *Config { CAData: config.TLSClientConfig.CAData, NextProtos: config.TLSClientConfig.NextProtos, }, - RateLimiter: config.RateLimiter, - WarningHandler: config.WarningHandler, - UserAgent: config.UserAgent, - DisableCompression: config.DisableCompression, - QPS: config.QPS, - Burst: config.Burst, - Timeout: config.Timeout, - Dial: config.Dial, - Proxy: config.Proxy, + RateLimiter: config.RateLimiter, + WarningHandler: config.WarningHandler, + WarningHandlerWithContext: config.WarningHandlerWithContext, + UserAgent: config.UserAgent, + DisableCompression: config.DisableCompression, + QPS: config.QPS, + Burst: config.Burst, + Timeout: config.Timeout, + Dial: config.Dial, + Proxy: config.Proxy, } } @@ -658,17 +685,18 @@ func CopyConfig(config *Config) *Config { CAData: config.TLSClientConfig.CAData, NextProtos: config.TLSClientConfig.NextProtos, }, - UserAgent: config.UserAgent, - DisableCompression: config.DisableCompression, - Transport: config.Transport, - WrapTransport: config.WrapTransport, - QPS: config.QPS, - Burst: config.Burst, - RateLimiter: config.RateLimiter, - WarningHandler: config.WarningHandler, - Timeout: config.Timeout, - Dial: config.Dial, - Proxy: config.Proxy, + UserAgent: config.UserAgent, + DisableCompression: config.DisableCompression, + Transport: config.Transport, + WrapTransport: config.WrapTransport, + QPS: config.QPS, + Burst: config.Burst, + RateLimiter: config.RateLimiter, + WarningHandler: config.WarningHandler, + WarningHandlerWithContext: config.WarningHandlerWithContext, + Timeout: config.Timeout, + Dial: config.Dial, + Proxy: config.Proxy, } if config.ExecProvider != nil && config.ExecProvider.Config != nil { c.ExecProvider.Config = config.ExecProvider.Config.DeepCopyObject() diff --git a/staging/src/k8s.io/client-go/rest/config_test.go b/staging/src/k8s.io/client-go/rest/config_test.go index 4fc74f545a1..6475813fc3f 100644 --- a/staging/src/k8s.io/client-go/rest/config_test.go +++ b/staging/src/k8s.io/client-go/rest/config_test.go @@ -41,6 +41,7 @@ import ( "github.com/google/go-cmp/cmp" fuzz "github.com/google/gofuzz" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIsConfigTransportTLS(t *testing.T) { @@ -266,6 +267,19 @@ type fakeWarningHandler struct{} func (f fakeWarningHandler) HandleWarningHeader(code int, agent string, message string) {} +type fakeWarningHandlerWithLogging struct { + messages []string +} + +func (f *fakeWarningHandlerWithLogging) HandleWarningHeader(code int, agent string, message string) { + f.messages = append(f.messages, message) +} + +type fakeWarningHandlerWithContext struct{} + +func (f fakeWarningHandlerWithContext) HandleWarningHeaderWithContext(ctx context.Context, code int, agent string, message string) { +} + type fakeNegotiatedSerializer struct{} func (n *fakeNegotiatedSerializer) SupportedMediaTypes() []runtime.SerializerInfo { @@ -330,6 +344,9 @@ func TestAnonymousAuthConfig(t *testing.T) { func(h *WarningHandler, f fuzz.Continue) { *h = &fakeWarningHandler{} }, + func(h *WarningHandlerWithContext, f fuzz.Continue) { + *h = &fakeWarningHandlerWithContext{} + }, // Authentication does not require fuzzer func(r *AuthProviderConfigPersister, f fuzz.Continue) {}, func(r *clientcmdapi.AuthProviderConfig, f fuzz.Continue) { @@ -428,6 +445,9 @@ func TestCopyConfig(t *testing.T) { func(h *WarningHandler, f fuzz.Continue) { *h = &fakeWarningHandler{} }, + func(h *WarningHandlerWithContext, f fuzz.Continue) { + *h = &fakeWarningHandlerWithContext{} + }, func(r *AuthProviderConfigPersister, f fuzz.Continue) { *r = fakeAuthProviderConfigPersister{} }, @@ -619,25 +639,69 @@ func TestConfigSprint(t *testing.T) { KeyData: []byte("fake key"), NextProtos: []string{"h2", "http/1.1"}, }, - UserAgent: "gobot", - Transport: &fakeRoundTripper{}, - WrapTransport: fakeWrapperFunc, - QPS: 1, - Burst: 2, - RateLimiter: &fakeLimiter{}, - WarningHandler: fakeWarningHandler{}, - Timeout: 3 * time.Second, - Dial: fakeDialFunc, - Proxy: fakeProxyFunc, + UserAgent: "gobot", + Transport: &fakeRoundTripper{}, + WrapTransport: fakeWrapperFunc, + QPS: 1, + Burst: 2, + RateLimiter: &fakeLimiter{}, + WarningHandler: fakeWarningHandler{}, + WarningHandlerWithContext: fakeWarningHandlerWithContext{}, + Timeout: 3 * time.Second, + Dial: fakeDialFunc, + Proxy: fakeProxyFunc, } want := fmt.Sprintf( - `&rest.Config{Host:"localhost:8080", APIPath:"v1", ContentConfig:rest.ContentConfig{AcceptContentTypes:"application/json", ContentType:"application/json", GroupVersion:(*schema.GroupVersion)(nil), NegotiatedSerializer:runtime.NegotiatedSerializer(nil)}, Username:"gopher", Password:"--- REDACTED ---", BearerToken:"--- REDACTED ---", BearerTokenFile:"", Impersonate:rest.ImpersonationConfig{UserName:"gopher2", UID:"uid123", Groups:[]string(nil), Extra:map[string][]string(nil)}, AuthProvider:api.AuthProviderConfig{Name: "gopher", Config: map[string]string{--- REDACTED ---}}, AuthConfigPersister:rest.AuthProviderConfigPersister(--- REDACTED ---), ExecProvider:api.ExecConfig{Command: "sudo", Args: []string{"--- REDACTED ---"}, Env: []ExecEnvVar{--- REDACTED ---}, APIVersion: "", ProvideClusterInfo: true, Config: runtime.Object(--- REDACTED ---), StdinUnavailable: false}, TLSClientConfig:rest.sanitizedTLSClientConfig{Insecure:false, ServerName:"", CertFile:"a.crt", KeyFile:"a.key", CAFile:"", CertData:[]uint8{0x2d, 0x2d, 0x2d, 0x20, 0x54, 0x52, 0x55, 0x4e, 0x43, 0x41, 0x54, 0x45, 0x44, 0x20, 0x2d, 0x2d, 0x2d}, KeyData:[]uint8{0x2d, 0x2d, 0x2d, 0x20, 0x52, 0x45, 0x44, 0x41, 0x43, 0x54, 0x45, 0x44, 0x20, 0x2d, 0x2d, 0x2d}, CAData:[]uint8(nil), NextProtos:[]string{"h2", "http/1.1"}}, UserAgent:"gobot", DisableCompression:false, Transport:(*rest.fakeRoundTripper)(%p), WrapTransport:(transport.WrapperFunc)(%p), QPS:1, Burst:2, RateLimiter:(*rest.fakeLimiter)(%p), WarningHandler:rest.fakeWarningHandler{}, Timeout:3000000000, Dial:(func(context.Context, string, string) (net.Conn, error))(%p), Proxy:(func(*http.Request) (*url.URL, error))(%p)}`, + `&rest.Config{Host:"localhost:8080", APIPath:"v1", ContentConfig:rest.ContentConfig{AcceptContentTypes:"application/json", ContentType:"application/json", GroupVersion:(*schema.GroupVersion)(nil), NegotiatedSerializer:runtime.NegotiatedSerializer(nil)}, Username:"gopher", Password:"--- REDACTED ---", BearerToken:"--- REDACTED ---", BearerTokenFile:"", Impersonate:rest.ImpersonationConfig{UserName:"gopher2", UID:"uid123", Groups:[]string(nil), Extra:map[string][]string(nil)}, AuthProvider:api.AuthProviderConfig{Name: "gopher", Config: map[string]string{--- REDACTED ---}}, AuthConfigPersister:rest.AuthProviderConfigPersister(--- REDACTED ---), ExecProvider:api.ExecConfig{Command: "sudo", Args: []string{"--- REDACTED ---"}, Env: []ExecEnvVar{--- REDACTED ---}, APIVersion: "", ProvideClusterInfo: true, Config: runtime.Object(--- REDACTED ---), StdinUnavailable: false}, TLSClientConfig:rest.sanitizedTLSClientConfig{Insecure:false, ServerName:"", CertFile:"a.crt", KeyFile:"a.key", CAFile:"", CertData:[]uint8{0x2d, 0x2d, 0x2d, 0x20, 0x54, 0x52, 0x55, 0x4e, 0x43, 0x41, 0x54, 0x45, 0x44, 0x20, 0x2d, 0x2d, 0x2d}, KeyData:[]uint8{0x2d, 0x2d, 0x2d, 0x20, 0x52, 0x45, 0x44, 0x41, 0x43, 0x54, 0x45, 0x44, 0x20, 0x2d, 0x2d, 0x2d}, CAData:[]uint8(nil), NextProtos:[]string{"h2", "http/1.1"}}, UserAgent:"gobot", DisableCompression:false, Transport:(*rest.fakeRoundTripper)(%p), WrapTransport:(transport.WrapperFunc)(%p), QPS:1, Burst:2, RateLimiter:(*rest.fakeLimiter)(%p), WarningHandler:rest.fakeWarningHandler{}, WarningHandlerWithContext:rest.fakeWarningHandlerWithContext{}, Timeout:3000000000, Dial:(func(context.Context, string, string) (net.Conn, error))(%p), Proxy:(func(*http.Request) (*url.URL, error))(%p)}`, c.Transport, fakeWrapperFunc, c.RateLimiter, fakeDialFunc, fakeProxyFunc, ) for _, f := range []string{"%s", "%v", "%+v", "%#v"} { if got := fmt.Sprintf(f, c); want != got { - t.Errorf("fmt.Sprintf(%q, c)\ngot: %q\nwant: %q", f, got, want) + t.Errorf("fmt.Sprintf(%q, c)\ngot: %q\nwant: %q\ndiff: %s", f, got, want, cmp.Diff(want, got)) } } } + +func TestConfigWarningHandler(t *testing.T) { + config := &Config{} + config.GroupVersion = &schema.GroupVersion{} + config.NegotiatedSerializer = &fakeNegotiatedSerializer{} + handlerNoContext := &fakeWarningHandler{} + handlerWithContext := &fakeWarningHandlerWithContext{} + + t.Run("none", func(t *testing.T) { + client, err := RESTClientForConfigAndClient(config, nil) + require.NoError(t, err) + assert.Nil(t, client.warningHandler) + }) + + t.Run("no-context", func(t *testing.T) { + config := CopyConfig(config) + handler := &fakeWarningHandlerWithLogging{} + config.WarningHandler = handler + client, err := RESTClientForConfigAndClient(config, nil) + require.NoError(t, err) + client.warningHandler.HandleWarningHeaderWithContext(context.Background(), 0, "", "message") + assert.Equal(t, []string{"message"}, handler.messages) + + }) + + t.Run("with-context", func(t *testing.T) { + config := CopyConfig(config) + config.WarningHandlerWithContext = handlerWithContext + client, err := RESTClientForConfigAndClient(config, nil) + require.NoError(t, err) + assert.Equal(t, handlerWithContext, client.warningHandler) + }) + + t.Run("both", func(t *testing.T) { + config := CopyConfig(config) + config.WarningHandler = handlerNoContext + config.WarningHandlerWithContext = handlerWithContext + client, err := RESTClientForConfigAndClient(config, nil) + require.NoError(t, err) + assert.NotNil(t, client.warningHandler) + assert.Equal(t, handlerWithContext, client.warningHandler) + }) +} diff --git a/staging/src/k8s.io/client-go/rest/exec_test.go b/staging/src/k8s.io/client-go/rest/exec_test.go index 5469c6d037b..6ab7bcc9730 100644 --- a/staging/src/k8s.io/client-go/rest/exec_test.go +++ b/staging/src/k8s.io/client-go/rest/exec_test.go @@ -242,6 +242,9 @@ func TestConfigToExecClusterRoundtrip(t *testing.T) { func(h *WarningHandler, f fuzz.Continue) { *h = &fakeWarningHandler{} }, + func(h *WarningHandlerWithContext, f fuzz.Continue) { + *h = &fakeWarningHandlerWithContext{} + }, // Authentication does not require fuzzer func(r *AuthProviderConfigPersister, f fuzz.Continue) {}, func(r *clientcmdapi.AuthProviderConfig, f fuzz.Continue) { @@ -289,6 +292,7 @@ func TestConfigToExecClusterRoundtrip(t *testing.T) { expected.Burst = 0 expected.RateLimiter = nil expected.WarningHandler = nil + expected.WarningHandlerWithContext = nil expected.Timeout = 0 expected.Dial = nil diff --git a/staging/src/k8s.io/client-go/rest/request.go b/staging/src/k8s.io/client-go/rest/request.go index 0ec90ad188b..b10db0ad8f3 100644 --- a/staging/src/k8s.io/client-go/rest/request.go +++ b/staging/src/k8s.io/client-go/rest/request.go @@ -103,7 +103,7 @@ type Request struct { contentConfig ClientContentConfig contentTypeNotSet bool - warningHandler WarningHandler + warningHandler WarningHandlerWithContext rateLimiter flowcontrol.RateLimiter backoff BackoffManager @@ -271,8 +271,21 @@ func (r *Request) BackOff(manager BackoffManager) *Request { } // WarningHandler sets the handler this client uses when warning headers are encountered. -// If set to nil, this client will use the default warning handler (see SetDefaultWarningHandler). +// If set to nil, this client will use the default warning handler (see [SetDefaultWarningHandler]). +// +//logcheck:context // WarningHandlerWithContext should be used instead of WarningHandler in code which supports contextual logging. func (r *Request) WarningHandler(handler WarningHandler) *Request { + if handler == nil { + r.warningHandler = nil + return r + } + r.warningHandler = warningLoggerNopContext{l: handler} + return r +} + +// WarningHandlerWithContext sets the handler this client uses when warning headers are encountered. +// If set to nil, this client will use the default warning handler (see [SetDefaultWarningHandlerWithContext]). +func (r *Request) WarningHandlerWithContext(handler WarningHandlerWithContext) *Request { r.warningHandler = handler return r } @@ -776,7 +789,7 @@ func (r *Request) watchInternal(ctx context.Context) (watch.Interface, runtime.D resp, err := client.Do(req) retry.After(ctx, r, resp, err) if err == nil && resp.StatusCode == http.StatusOK { - return r.newStreamWatcher(resp) + return r.newStreamWatcher(ctx, resp) } done, transformErr := func() (bool, error) { @@ -969,7 +982,7 @@ func (r *Request) handleWatchList(ctx context.Context, w watch.Interface, negoti } } -func (r *Request) newStreamWatcher(resp *http.Response) (watch.Interface, runtime.Decoder, error) { +func (r *Request) newStreamWatcher(ctx context.Context, resp *http.Response) (watch.Interface, runtime.Decoder, error) { contentType := resp.Header.Get("Content-Type") mediaType, params, err := mime.ParseMediaType(contentType) if err != nil { @@ -980,7 +993,7 @@ func (r *Request) newStreamWatcher(resp *http.Response) (watch.Interface, runtim return nil, nil, err } - handleWarnings(resp.Header, r.warningHandler) + handleWarnings(ctx, resp.Header, r.warningHandler) frameReader := framer.NewFrameReader(resp.Body) watchEventDecoder := streaming.NewDecoder(frameReader, streamingSerializer) @@ -1067,7 +1080,7 @@ func (r *Request) Stream(ctx context.Context) (io.ReadCloser, error) { switch { case (resp.StatusCode >= 200) && (resp.StatusCode < 300): - handleWarnings(resp.Header, r.warningHandler) + handleWarnings(ctx, resp.Header, r.warningHandler) return resp.Body, nil default: @@ -1365,7 +1378,7 @@ func (r *Request) transformResponse(ctx context.Context, resp *http.Response, re body: body, contentType: contentType, statusCode: resp.StatusCode, - warnings: handleWarnings(resp.Header, r.warningHandler), + warnings: handleWarnings(ctx, resp.Header, r.warningHandler), } } } @@ -1384,7 +1397,7 @@ func (r *Request) transformResponse(ctx context.Context, resp *http.Response, re statusCode: resp.StatusCode, decoder: decoder, err: err, - warnings: handleWarnings(resp.Header, r.warningHandler), + warnings: handleWarnings(ctx, resp.Header, r.warningHandler), } } @@ -1393,7 +1406,7 @@ func (r *Request) transformResponse(ctx context.Context, resp *http.Response, re contentType: contentType, statusCode: resp.StatusCode, decoder: decoder, - warnings: handleWarnings(resp.Header, r.warningHandler), + warnings: handleWarnings(ctx, resp.Header, r.warningHandler), } } diff --git a/staging/src/k8s.io/client-go/rest/request_test.go b/staging/src/k8s.io/client-go/rest/request_test.go index 186c5a35b9f..013a22816ec 100644 --- a/staging/src/k8s.io/client-go/rest/request_test.go +++ b/staging/src/k8s.io/client-go/rest/request_test.go @@ -4066,15 +4066,24 @@ func TestRequestLogging(t *testing.T) { testcases := map[string]struct { v int body any + response *http.Response expectedOutput string }{ "no-output": { v: 7, body: []byte("ping"), + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("pong")), + }, }, "output": { v: 8, body: []byte("ping"), + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("pong")), + }, expectedOutput: `] "Request Body" logger="TestLogger" body="ping" ] "Response Body" logger="TestLogger" body="pong" `, @@ -4082,6 +4091,10 @@ func TestRequestLogging(t *testing.T) { "io-reader": { v: 8, body: strings.NewReader("ping"), + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("pong")), + }, // Cannot log the request body! expectedOutput: `] "Response Body" logger="TestLogger" body="pong" `, @@ -4089,10 +4102,34 @@ func TestRequestLogging(t *testing.T) { "truncate": { v: 8, body: []byte(strings.Repeat("a", 2000)), + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("pong")), + }, expectedOutput: fmt.Sprintf(`] "Request Body" logger="TestLogger" body="%s [truncated 976 chars]" ] "Response Body" logger="TestLogger" body="pong" `, strings.Repeat("a", 1024)), }, + "warnings": { + v: 8, + body: []byte("ping"), + response: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Warning": []string{ + `299 request-test "warning 1"`, + `299 request-test-2 "warning 2"`, + `300 request-test-3 "ignore code 300"`, + }, + }, + Body: io.NopCloser(strings.NewReader("pong")), + }, + expectedOutput: `] "Request Body" logger="TestLogger" body="ping" +] "Response Body" logger="TestLogger" body="pong" +warnings.go] "Warning: warning 1" logger="TestLogger" +warnings.go] "Warning: warning 2" logger="TestLogger" +`, + }, } for name, tc := range testcases { @@ -4106,12 +4143,10 @@ func TestRequestLogging(t *testing.T) { var fs flag.FlagSet klog.InitFlags(&fs) require.NoError(t, fs.Set("v", fmt.Sprintf("%d", tc.v)), "set verbosity") + require.NoError(t, fs.Set("one_output", "true"), "set one_output") client := clientForFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("pong")), - }, nil + return tc.response, nil }) req := NewRequestWithClient(nil, "", defaultContentConfig(), client). @@ -4128,11 +4163,49 @@ func TestRequestLogging(t *testing.T) { // Compare log output: // - strip date/time/pid from each line (fixed length header) // - replace with the actual call location + // - strip line number from warnings.go (might change) state.Restore() expectedOutput := strings.ReplaceAll(tc.expectedOutput, "", fmt.Sprintf("%s:%d", path.Base(file), line+1)) actualOutput := buffer.String() actualOutput = regexp.MustCompile(`(?m)^.{30}`).ReplaceAllString(actualOutput, "") + actualOutput = regexp.MustCompile(`(?m)^warnings\.go:\d+`).ReplaceAllString(actualOutput, "warnings.go") assert.Equal(t, expectedOutput, actualOutput) }) } } + +func TestRequestWarningHandler(t *testing.T) { + t.Run("no-context", func(t *testing.T) { + request := &Request{} + handler := &fakeWarningHandlerWithLogging{} + //nolint:logcheck + assert.Equal(t, request, request.WarningHandler(handler)) + assert.NotNil(t, request.warningHandler) + request.warningHandler.HandleWarningHeaderWithContext(context.Background(), 0, "", "message") + assert.Equal(t, []string{"message"}, handler.messages) + }) + + t.Run("with-context", func(t *testing.T) { + request := &Request{} + handler := &fakeWarningHandlerWithContext{} + assert.Equal(t, request, request.WarningHandlerWithContext(handler)) + assert.Equal(t, request.warningHandler, handler) + }) + + t.Run("nil-no-context", func(t *testing.T) { + request := &Request{ + warningHandler: &fakeWarningHandlerWithContext{}, + } + //nolint:logcheck + assert.Equal(t, request, request.WarningHandler(nil)) + assert.Nil(t, request.warningHandler) + }) + + t.Run("nil-with-context", func(t *testing.T) { + request := &Request{ + warningHandler: &fakeWarningHandlerWithContext{}, + } + assert.Equal(t, request, request.WarningHandlerWithContext(nil)) + assert.Nil(t, request.warningHandler) + }) +} diff --git a/staging/src/k8s.io/client-go/rest/warnings.go b/staging/src/k8s.io/client-go/rest/warnings.go index ad493659f22..713b2d64d64 100644 --- a/staging/src/k8s.io/client-go/rest/warnings.go +++ b/staging/src/k8s.io/client-go/rest/warnings.go @@ -17,6 +17,7 @@ limitations under the License. package rest import ( + "context" "fmt" "io" "net/http" @@ -33,8 +34,15 @@ type WarningHandler interface { HandleWarningHeader(code int, agent string, text string) } +// WarningHandlerWithContext is an interface for handling warning headers with +// support for contextual logging. +type WarningHandlerWithContext interface { + // HandleWarningHeaderWithContext is called with the warn code, agent, and text when a warning header is countered. + HandleWarningHeaderWithContext(ctx context.Context, code int, agent string, text string) +} + var ( - defaultWarningHandler WarningHandler = WarningLogger{} + defaultWarningHandler WarningHandlerWithContext = WarningLogger{} defaultWarningHandlerLock sync.RWMutex ) @@ -43,33 +51,68 @@ var ( // - NoWarnings suppresses warnings. // - WarningLogger logs warnings. // - NewWarningWriter() outputs warnings to the provided writer. +// +// logcheck:context // SetDefaultWarningHandlerWithContext should be used instead of SetDefaultWarningHandler in code which supports contextual logging. func SetDefaultWarningHandler(l WarningHandler) { + if l == nil { + SetDefaultWarningHandlerWithContext(nil) + return + } + SetDefaultWarningHandlerWithContext(warningLoggerNopContext{l: l}) +} + +// SetDefaultWarningHandlerWithContext is a variant of [SetDefaultWarningHandler] which supports contextual logging. +func SetDefaultWarningHandlerWithContext(l WarningHandlerWithContext) { defaultWarningHandlerLock.Lock() defer defaultWarningHandlerLock.Unlock() defaultWarningHandler = l } -func getDefaultWarningHandler() WarningHandler { + +func getDefaultWarningHandler() WarningHandlerWithContext { defaultWarningHandlerLock.RLock() defer defaultWarningHandlerLock.RUnlock() l := defaultWarningHandler return l } -// NoWarnings is an implementation of WarningHandler that suppresses warnings. +type warningLoggerNopContext struct { + l WarningHandler +} + +func (w warningLoggerNopContext) HandleWarningHeaderWithContext(_ context.Context, code int, agent string, message string) { + w.l.HandleWarningHeader(code, agent, message) +} + +// NoWarnings is an implementation of [WarningHandler] and [WarningHandlerWithContext] that suppresses warnings. type NoWarnings struct{} func (NoWarnings) HandleWarningHeader(code int, agent string, message string) {} +func (NoWarnings) HandleWarningHeaderWithContext(ctx context.Context, code int, agent string, message string) { +} -// WarningLogger is an implementation of WarningHandler that logs code 299 warnings +var _ WarningHandler = NoWarnings{} +var _ WarningHandlerWithContext = NoWarnings{} + +// WarningLogger is an implementation of [WarningHandler] and [WarningHandlerWithContext] that logs code 299 warnings type WarningLogger struct{} func (WarningLogger) HandleWarningHeader(code int, agent string, message string) { if code != 299 || len(message) == 0 { return } - klog.Warning(message) + klog.Background().Info("Warning: " + message) } +func (WarningLogger) HandleWarningHeaderWithContext(ctx context.Context, code int, agent string, message string) { + if code != 299 || len(message) == 0 { + return + } + klog.FromContext(ctx).Info("Warning: " + message) +} + +var _ WarningHandler = WarningLogger{} +var _ WarningHandlerWithContext = WarningLogger{} + type warningWriter struct { // out is the writer to output warnings to out io.Writer @@ -134,14 +177,14 @@ func (w *warningWriter) WarningCount() int { return w.writtenCount } -func handleWarnings(headers http.Header, handler WarningHandler) []net.WarningHeader { +func handleWarnings(ctx context.Context, headers http.Header, handler WarningHandlerWithContext) []net.WarningHeader { if handler == nil { handler = getDefaultWarningHandler() } warnings, _ := net.ParseWarningHeaders(headers["Warning"]) for _, warning := range warnings { - handler.HandleWarningHeader(warning.Code, warning.Agent, warning.Text) + handler.HandleWarningHeaderWithContext(ctx, warning.Code, warning.Agent, warning.Text) } return warnings } diff --git a/staging/src/k8s.io/client-go/rest/warnings_test.go b/staging/src/k8s.io/client-go/rest/warnings_test.go new file mode 100644 index 00000000000..d74964310db --- /dev/null +++ b/staging/src/k8s.io/client-go/rest/warnings_test.go @@ -0,0 +1,57 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package rest + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultWarningHandler(t *testing.T) { + t.Run("default", func(t *testing.T) { + assert.IsType(t, WarningHandlerWithContext(WarningLogger{}), getDefaultWarningHandler()) + }) + + deferRestore := func(t *testing.T) { + handler := getDefaultWarningHandler() + t.Cleanup(func() { + SetDefaultWarningHandlerWithContext(handler) + }) + } + + t.Run("no-context", func(t *testing.T) { + deferRestore(t) + handler := &fakeWarningHandlerWithLogging{} + //nolint:logcheck + SetDefaultWarningHandler(handler) + getDefaultWarningHandler().HandleWarningHeaderWithContext(context.Background(), 0, "", "message") + assert.Equal(t, []string{"message"}, handler.messages) + SetDefaultWarningHandler(nil) + assert.Nil(t, getDefaultWarningHandler()) + }) + + t.Run("with-context", func(t *testing.T) { + deferRestore(t) + handler := &fakeWarningHandlerWithContext{} + SetDefaultWarningHandlerWithContext(handler) + assert.Equal(t, handler, getDefaultWarningHandler()) + SetDefaultWarningHandlerWithContext(nil) + assert.Nil(t, getDefaultWarningHandler()) + }) +}