diff --git a/transport/round_trippers.go b/transport/round_trippers.go index 459a9376..316a5c0d 100644 --- a/transport/round_trippers.go +++ b/transport/round_trippers.go @@ -129,7 +129,7 @@ func SetAuthProxyHeaders(req *http.Request, username string, groups []string, ex } for key, values := range extra { for _, value := range values { - req.Header.Add("X-Remote-Extra-"+key, value) + req.Header.Add("X-Remote-Extra-"+headerKeyEscape(key), value) } } } @@ -246,7 +246,7 @@ func (rt *impersonatingRoundTripper) RoundTrip(req *http.Request) (*http.Respons } for k, vv := range rt.impersonate.Extra { for _, v := range vv { - req.Header.Add(ImpersonateUserExtraHeaderPrefix+k, v) + req.Header.Add(ImpersonateUserExtraHeaderPrefix+headerKeyEscape(k), v) } } @@ -422,3 +422,110 @@ func (rt *debuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e func (rt *debuggingRoundTripper) WrappedRoundTripper() http.RoundTripper { return rt.delegatedRoundTripper } + +func legalHeaderByte(b byte) bool { + return int(b) < len(legalHeaderKeyBytes) && legalHeaderKeyBytes[b] +} + +func shouldEscape(b byte) bool { + // url.PathUnescape() returns an error if any '%' is not followed by two + // hexadecimal digits, so we'll intentionally encode it. + return !legalHeaderByte(b) || b == '%' +} + +func headerKeyEscape(key string) string { + buf := strings.Builder{} + for i := 0; i < len(key); i++ { + b := key[i] + if shouldEscape(b) { + // %-encode bytes that should be escaped: + // https://tools.ietf.org/html/rfc3986#section-2.1 + fmt.Fprintf(&buf, "%%%02X", b) + continue + } + buf.WriteByte(b) + } + return buf.String() +} + +// legalHeaderKeyBytes was copied from net/http/lex.go's isTokenTable. +// See https://httpwg.github.io/specs/rfc7230.html#rule.token.separators +var legalHeaderKeyBytes = [127]bool{ + '%': true, + '!': true, + '#': true, + '$': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} diff --git a/transport/round_trippers_test.go b/transport/round_trippers_test.go index d5ffc6bd..74d3dc21 100644 --- a/transport/round_trippers_test.go +++ b/transport/round_trippers_test.go @@ -18,6 +18,7 @@ package transport import ( "net/http" + "net/url" "reflect" "strings" "testing" @@ -125,6 +126,32 @@ func TestImpersonationRoundTripper(t *testing.T) { ImpersonateUserExtraHeaderPrefix + "Second": {"B", "b"}, }, }, + { + name: "escape handling", + impersonationConfig: ImpersonationConfig{ + UserName: "user", + Extra: map[string][]string{ + "test.example.com/thing.thing": {"A", "a"}, + }, + }, + expected: map[string][]string{ + ImpersonateUserHeader: {"user"}, + ImpersonateUserExtraHeaderPrefix + `Test.example.com%2fthing.thing`: {"A", "a"}, + }, + }, + { + name: "double escape handling", + impersonationConfig: ImpersonationConfig{ + UserName: "user", + Extra: map[string][]string{ + "test.example.com/thing.thing%20another.thing": {"A", "a"}, + }, + }, + expected: map[string][]string{ + ImpersonateUserHeader: {"user"}, + ImpersonateUserExtraHeaderPrefix + `Test.example.com%2fthing.thing%2520another.thing`: {"A", "a"}, + }, + }, } for _, tc := range tcs { @@ -159,9 +186,10 @@ func TestImpersonationRoundTripper(t *testing.T) { func TestAuthProxyRoundTripper(t *testing.T) { for n, tc := range map[string]struct { - username string - groups []string - extra map[string][]string + username string + groups []string + extra map[string][]string + expectedExtra map[string][]string }{ "allfields": { username: "user", @@ -170,6 +198,34 @@ func TestAuthProxyRoundTripper(t *testing.T) { "one": {"alpha", "bravo"}, "two": {"charlie", "delta"}, }, + expectedExtra: map[string][]string{ + "one": {"alpha", "bravo"}, + "two": {"charlie", "delta"}, + }, + }, + "escaped extra": { + username: "user", + groups: []string{"groupA", "groupB"}, + extra: map[string][]string{ + "one": {"alpha", "bravo"}, + "example.com/two": {"charlie", "delta"}, + }, + expectedExtra: map[string][]string{ + "one": {"alpha", "bravo"}, + "example.com%2ftwo": {"charlie", "delta"}, + }, + }, + "double escaped extra": { + username: "user", + groups: []string{"groupA", "groupB"}, + extra: map[string][]string{ + "one": {"alpha", "bravo"}, + "example.com/two%20three": {"charlie", "delta"}, + }, + expectedExtra: map[string][]string{ + "one": {"alpha", "bravo"}, + "example.com%2ftwo%2520three": {"charlie", "delta"}, + }, }, } { rt := &testRoundTripper{} @@ -210,9 +266,64 @@ func TestAuthProxyRoundTripper(t *testing.T) { actualExtra[extraKey] = append(actualExtra[key], values...) } } - if e, a := tc.extra, actualExtra; !reflect.DeepEqual(e, a) { + if e, a := tc.expectedExtra, actualExtra; !reflect.DeepEqual(e, a) { t.Errorf("%s expected %v, got %v", n, e, a) continue } } } + +// TestHeaderEscapeRoundTrip tests to see if foo == url.PathUnescape(headerEscape(foo)) +// This behavior is important for client -> API server transmission of extra values. +func TestHeaderEscapeRoundTrip(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + key string + }{ + { + name: "alpha", + key: "alphabetical", + }, + { + name: "alphanumeric", + key: "alph4num3r1c", + }, + { + name: "percent encoded", + key: "percent%20encoded", + }, + { + name: "almost percent encoded", + key: "almost%zzpercent%xxencoded", + }, + { + name: "illegal char & percent encoding", + key: "example.com/percent%20encoded", + }, + { + name: "weird unicode stuff", + key: "example.com/ᛒᚥᛏᛖᚥᚢとロビン", + }, + { + name: "header legal chars", + key: "abc123!#$+.-_*\\^`~|'", + }, + { + name: "legal path, illegal header", + key: "@=:", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + escaped := headerKeyEscape(tc.key) + unescaped, err := url.PathUnescape(escaped) + if err != nil { + t.Fatalf("url.PathUnescape(%q) returned error: %v", escaped, err) + } + if tc.key != unescaped { + t.Errorf("url.PathUnescape(headerKeyEscape(%q)) returned %q, wanted %q", tc.key, unescaped, tc.key) + } + }) + } +}