Merge pull request #114190 from tkashem/cors-refactor

refactor CORS handler
This commit is contained in:
Kubernetes Prow Robot 2022-12-16 13:08:09 -08:00 committed by GitHub
commit 0f7409a230
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 166 additions and 89 deletions

View File

@ -38,30 +38,38 @@ func WithCORS(handler http.Handler, allowedOriginPatterns []string, allowedMetho
return handler return handler
} }
allowedOriginPatternsREs := allowedOriginRegexps(allowedOriginPatterns) allowedOriginPatternsREs := allowedOriginRegexps(allowedOriginPatterns)
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
origin := req.Header.Get("Origin")
if origin != "" {
allowed := false
for _, re := range allowedOriginPatternsREs {
if allowed = re.MatchString(origin); allowed {
break
}
}
if allowed {
w.Header().Set("Access-Control-Allow-Origin", origin)
// Set defaults for methods and headers if nothing was passed // Set defaults for methods and headers if nothing was passed
if allowedMethods == nil { if allowedMethods == nil {
allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"} allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"}
} }
allowMethodsResponseHeader := strings.Join(allowedMethods, ", ")
if allowedHeaders == nil { if allowedHeaders == nil {
allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"} allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"}
} }
allowHeadersResponseHeader := strings.Join(allowedHeaders, ", ")
if exposedHeaders == nil { if exposedHeaders == nil {
exposedHeaders = []string{"Date"} exposedHeaders = []string{"Date"}
} }
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", ")) exposeHeadersResponseHeader := strings.Join(exposedHeaders, ", ")
w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", "))
w.Header().Set("Access-Control-Expose-Headers", strings.Join(exposedHeaders, ", ")) return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
origin := req.Header.Get("Origin")
if origin == "" {
handler.ServeHTTP(w, req)
return
}
if !isOriginAllowed(origin, allowedOriginPatternsREs) {
handler.ServeHTTP(w, req)
return
}
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", allowMethodsResponseHeader)
w.Header().Set("Access-Control-Allow-Headers", allowHeadersResponseHeader)
w.Header().Set("Access-Control-Expose-Headers", exposeHeadersResponseHeader)
w.Header().Set("Access-Control-Allow-Credentials", allowCredentials) w.Header().Set("Access-Control-Allow-Credentials", allowCredentials)
// Stop here if its a preflight OPTIONS request // Stop here if its a preflight OPTIONS request
@ -69,13 +77,37 @@ func WithCORS(handler http.Handler, allowedOriginPatterns []string, allowedMetho
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
return return
} }
}
}
// Dispatch to the next handler // Dispatch to the next handler
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
}) })
} }
// isOriginAllowed returns true if the given origin header in the
// request is allowed CORS.
//
// From https://www.rfc-editor.org/rfc/rfc6454#page-13
//
// a) The origin header can contain host and/or port
// serialized-origin = scheme "://" host [ ":" port ]
//
// b) In some cases, a number of origins contribute to causing the user
// agents to issue an HTTP request. In those cases, the user agent MAY
// list all the origins in the Origin header field. For example, if the
// HTTP request was initially issued by one origin but then later
// redirected by another origin, the user agent MAY inform the server
// that two origins were involved in causing the user agent to issue the
// request
// origin-list = serialized-origin *( SP serialized-origin )
func isOriginAllowed(originHeader string, allowedOriginPatternsREs []*regexp.Regexp) bool {
for _, re := range allowedOriginPatternsREs {
if re.MatchString(originHeader) {
return true
}
}
return false
}
func allowedOriginRegexps(allowedOrigins []string) []*regexp.Regexp { func allowedOriginRegexps(allowedOrigins []string) []*regexp.Regexp {
res, err := compileRegexps(allowedOrigins) res, err := compileRegexps(allowedOrigins)
if err != nil { if err != nil {

View File

@ -17,6 +17,7 @@ limitations under the License.
package filters package filters
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
@ -25,22 +26,60 @@ import (
) )
func TestCORSAllowedOrigins(t *testing.T) { func TestCORSAllowedOrigins(t *testing.T) {
table := []struct { tests := []struct {
name string
allowedOrigins []string allowedOrigins []string
origin string origins []string
allowed bool allowed bool
}{ }{
{[]string{}, "example.com", false}, {
{[]string{"example.com"}, "example.com", true}, name: "allowed origins list is empty",
{[]string{"example.com"}, "not-allowed.com", false}, allowedOrigins: []string{},
{[]string{"not-matching.com", "example.com"}, "example.com", true}, origins: []string{"example.com"},
{[]string{".*"}, "example.com", true}, allowed: false,
},
{
name: "origin request header not set",
allowedOrigins: []string{"example.com"},
origins: []string{""},
allowed: false,
},
{
name: "allowed regexp is a match",
allowedOrigins: []string{"example.com"},
origins: []string{"http://example.com", "example.com"},
allowed: true,
},
{
name: "allowed regexp is not a match",
allowedOrigins: []string{"example.com"},
origins: []string{"http://not-allowed.com", "not-allowed.com"},
allowed: false,
},
{
name: "allowed list with multiple regex",
allowedOrigins: []string{"not-matching.com", "example.com"},
origins: []string{"http://example.com", "example.com"},
allowed: true,
},
{
name: "wildcard matching",
allowedOrigins: []string{".*"},
origins: []string{"http://example.com", "example.com"},
allowed: true,
},
} }
for _, item := range table { for _, test := range tests {
for _, origin := range test.origins {
name := fmt.Sprintf("%s/origin/%s", test.name, origin)
t.Run(name, func(t *testing.T) {
var handlerInvoked int
handler := WithCORS( handler := WithCORS(
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}), http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
item.allowedOrigins, nil, nil, nil, "true", handlerInvoked++
}),
test.allowedOrigins, nil, nil, nil, "true",
) )
var response *http.Response var response *http.Response
func() { func() {
@ -51,16 +90,20 @@ func TestCORSAllowedOrigins(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
request.Header.Set("Origin", item.origin) request.Header.Set("Origin", origin)
client := http.Client{} client := http.Client{}
response, err = client.Do(request) response, err = client.Do(request)
if err != nil { if err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
}() }()
if item.allowed { if handlerInvoked != 1 {
if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) { t.Errorf("Expected the handler to be invoked once, but got: %d", handlerInvoked)
t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin")) }
if test.allowed {
if !reflect.DeepEqual(origin, response.Header.Get("Access-Control-Allow-Origin")) {
t.Errorf("Expected %#v, Got %#v", origin, response.Header.Get("Access-Control-Allow-Origin"))
} }
if response.Header.Get("Access-Control-Allow-Credentials") == "" { if response.Header.Get("Access-Control-Allow-Credentials") == "" {
@ -99,6 +142,8 @@ func TestCORSAllowedOrigins(t *testing.T) {
t.Errorf("Expected Date in Access-Control-Expose-Headers header") t.Errorf("Expected Date in Access-Control-Expose-Headers header")
} }
} }
})
}
} }
} }