diff --git a/transport/round_trippers.go b/transport/round_trippers.go index 56df8ead..cd0a4455 100644 --- a/transport/round_trippers.go +++ b/transport/round_trippers.go @@ -23,9 +23,9 @@ import ( "time" "golang.org/x/oauth2" - "k8s.io/klog/v2" utilnet "k8s.io/apimachinery/pkg/util/net" + "k8s.io/klog/v2" ) // HTTPWrappersForConfig wraps a round tripper with any relevant layered @@ -68,13 +68,13 @@ func HTTPWrappersForConfig(config *Config, rt http.RoundTripper) (http.RoundTrip func DebugWrappers(rt http.RoundTripper) http.RoundTripper { switch { case bool(klog.V(9).Enabled()): - rt = newDebuggingRoundTripper(rt, debugCurlCommand, debugURLTiming, debugResponseHeaders) + rt = NewDebuggingRoundTripper(rt, DebugCurlCommand, DebugURLTiming, DebugResponseHeaders) case bool(klog.V(8).Enabled()): - rt = newDebuggingRoundTripper(rt, debugJustURL, debugRequestHeaders, debugResponseStatus, debugResponseHeaders) + rt = NewDebuggingRoundTripper(rt, DebugJustURL, DebugRequestHeaders, DebugResponseStatus, DebugResponseHeaders) case bool(klog.V(7).Enabled()): - rt = newDebuggingRoundTripper(rt, debugJustURL, debugRequestHeaders, debugResponseStatus) + rt = NewDebuggingRoundTripper(rt, DebugJustURL, DebugRequestHeaders, DebugResponseStatus) case bool(klog.V(6).Enabled()): - rt = newDebuggingRoundTripper(rt, debugURLTiming) + rt = NewDebuggingRoundTripper(rt, DebugURLTiming) } return rt @@ -353,25 +353,35 @@ func (r *requestInfo) toCurl() string { // through it based on what is configured type debuggingRoundTripper struct { delegatedRoundTripper http.RoundTripper - - levels map[debugLevel]bool + levels map[DebugLevel]bool } -type debugLevel int +// DebugLevel is used to enable debugging of certain +// HTTP requests and responses fields via the debuggingRoundTripper. +type DebugLevel int const ( - debugJustURL debugLevel = iota - debugURLTiming - debugCurlCommand - debugRequestHeaders - debugResponseStatus - debugResponseHeaders + // DebugJustURL will add to the debug output HTTP requests method and url. + DebugJustURL DebugLevel = iota + // DebugURLTiming will add to the debug output the duration of HTTP requests. + DebugURLTiming + // DebugCurlCommand will add to the debug output the curl command equivalent to the + // HTTP request. + DebugCurlCommand + // DebugRequestHeaders will add to the debug output the HTTP requests headers. + DebugRequestHeaders + // DebugResponseStatus will add to the debug output the HTTP response status. + DebugResponseStatus + // DebugResponseHeaders will add to the debug output the HTTP response headers. + DebugResponseHeaders ) -func newDebuggingRoundTripper(rt http.RoundTripper, levels ...debugLevel) *debuggingRoundTripper { +// NewDebuggingRoundTripper allows to display in the logs output debug information +// on the API requests performed by the client. +func NewDebuggingRoundTripper(rt http.RoundTripper, levels ...DebugLevel) http.RoundTripper { drt := &debuggingRoundTripper{ delegatedRoundTripper: rt, - levels: make(map[debugLevel]bool, len(levels)), + levels: make(map[DebugLevel]bool, len(levels)), } for _, v := range levels { drt.levels[v] = true @@ -418,14 +428,13 @@ func maskValue(key string, value string) string { func (rt *debuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { reqInfo := newRequestInfo(req) - if rt.levels[debugJustURL] { + if rt.levels[DebugJustURL] { klog.Infof("%s %s", reqInfo.RequestVerb, reqInfo.RequestURL) } - if rt.levels[debugCurlCommand] { + if rt.levels[DebugCurlCommand] { klog.Infof("%s", reqInfo.toCurl()) - } - if rt.levels[debugRequestHeaders] { + if rt.levels[DebugRequestHeaders] { klog.Info("Request Headers:") for key, values := range reqInfo.RequestHeaders { for _, value := range values { @@ -441,13 +450,13 @@ func (rt *debuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e reqInfo.complete(response, err) - if rt.levels[debugURLTiming] { + if rt.levels[DebugURLTiming] { klog.Infof("%s %s %s in %d milliseconds", reqInfo.RequestVerb, reqInfo.RequestURL, reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond)) } - if rt.levels[debugResponseStatus] { + if rt.levels[DebugResponseStatus] { klog.Infof("Response Status: %s in %d milliseconds", reqInfo.ResponseStatus, reqInfo.Duration.Nanoseconds()/int64(time.Millisecond)) } - if rt.levels[debugResponseHeaders] { + if rt.levels[DebugResponseHeaders] { klog.Info("Response Headers:") for key, values := range reqInfo.ResponseHeaders { for _, value := range values { diff --git a/transport/round_trippers_test.go b/transport/round_trippers_test.go index ac8de240..10f7132c 100644 --- a/transport/round_trippers_test.go +++ b/transport/round_trippers_test.go @@ -17,11 +17,15 @@ limitations under the License. package transport import ( + "bytes" + "fmt" "net/http" "net/url" "reflect" "strings" "testing" + + "k8s.io/klog/v2" ) type testRoundTripper struct { @@ -412,3 +416,102 @@ func TestHeaderEscapeRoundTrip(t *testing.T) { }) } } + +func TestDebuggingRoundTripper(t *testing.T) { + t.Parallel() + + rawURL := "https://127.0.0.1:12345/api/v1/pods?limit=500" + req := &http.Request{ + Method: http.MethodGet, + Header: map[string][]string{ + "Authorization": {"bearer secretauthtoken"}, + "X-Test-Request": {"test"}, + }, + } + res := &http.Response{ + Status: "OK", + StatusCode: http.StatusOK, + Header: map[string][]string{ + "X-Test-Response": {"test"}, + }, + } + tcs := []struct { + levels []DebugLevel + expectedOutputLines []string + }{ + { + levels: []DebugLevel{DebugJustURL}, + expectedOutputLines: []string{fmt.Sprintf("%s %s", req.Method, rawURL)}, + }, + { + levels: []DebugLevel{DebugRequestHeaders}, + expectedOutputLines: func() []string { + lines := []string{fmt.Sprintf("Request Headers:\n")} + for key, values := range req.Header { + for _, value := range values { + if key == "Authorization" { + value = "bearer " + } + lines = append(lines, fmt.Sprintf(" %s: %s\n", key, value)) + } + } + return lines + }(), + }, + { + levels: []DebugLevel{DebugResponseHeaders}, + expectedOutputLines: func() []string { + lines := []string{fmt.Sprintf("Response Headers:\n")} + for key, values := range res.Header { + for _, value := range values { + lines = append(lines, fmt.Sprintf(" %s: %s\n", key, value)) + } + } + return lines + }(), + }, + { + levels: []DebugLevel{DebugURLTiming}, + expectedOutputLines: []string{fmt.Sprintf("%s %s %s", req.Method, rawURL, res.Status)}, + }, + { + levels: []DebugLevel{DebugResponseStatus}, + expectedOutputLines: []string{fmt.Sprintf("Response Status: %s", res.Status)}, + }, + { + levels: []DebugLevel{DebugCurlCommand}, + expectedOutputLines: []string{fmt.Sprintf("curl -k -v -X")}, + }, + } + + for _, tc := range tcs { + // hijack the klog output + tmpWriteBuffer := bytes.NewBuffer(nil) + klog.SetOutput(tmpWriteBuffer) + klog.LogToStderr(false) + + // parse rawURL + parsedURL, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("url.Parse(%q) returned error: %v", rawURL, err) + } + req.URL = parsedURL + + // execute the round tripper + rt := &testRoundTripper{ + Response: res, + } + NewDebuggingRoundTripper(rt, tc.levels...).RoundTrip(req) + + // call Flush to ensure the text isn't still buffered + klog.Flush() + + // check if klog's output contains the expected lines + actual := tmpWriteBuffer.String() + for _, expected := range tc.expectedOutputLines { + if !strings.Contains(actual, expected) { + t.Errorf("%q does not contain expected output %q", actual, expected) + } + } + } +}