mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-01 07:47:56 +00:00
Merge pull request #114190 from tkashem/cors-refactor
refactor CORS handler
This commit is contained in:
commit
0f7409a230
@ -38,30 +38,38 @@ func WithCORS(handler http.Handler, allowedOriginPatterns []string, allowedMetho
|
|||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
allowedOriginPatternsREs := allowedOriginRegexps(allowedOriginPatterns)
|
allowedOriginPatternsREs := allowedOriginRegexps(allowedOriginPatterns)
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
origin := req.Header.Get("Origin")
|
|
||||||
if origin != "" {
|
|
||||||
allowed := false
|
|
||||||
for _, re := range allowedOriginPatternsREs {
|
|
||||||
if allowed = re.MatchString(origin); allowed {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if allowed {
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
||||||
// Set defaults for methods and headers if nothing was passed
|
// Set defaults for methods and headers if nothing was passed
|
||||||
if allowedMethods == nil {
|
if allowedMethods == nil {
|
||||||
allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"}
|
allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"}
|
||||||
}
|
}
|
||||||
|
allowMethodsResponseHeader := strings.Join(allowedMethods, ", ")
|
||||||
|
|
||||||
if allowedHeaders == nil {
|
if allowedHeaders == nil {
|
||||||
allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"}
|
allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"}
|
||||||
}
|
}
|
||||||
|
allowHeadersResponseHeader := strings.Join(allowedHeaders, ", ")
|
||||||
|
|
||||||
if exposedHeaders == nil {
|
if exposedHeaders == nil {
|
||||||
exposedHeaders = []string{"Date"}
|
exposedHeaders = []string{"Date"}
|
||||||
}
|
}
|
||||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", "))
|
exposeHeadersResponseHeader := strings.Join(exposedHeaders, ", ")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", "))
|
|
||||||
w.Header().Set("Access-Control-Expose-Headers", strings.Join(exposedHeaders, ", "))
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
origin := req.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !isOriginAllowed(origin, allowedOriginPatternsREs) {
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
w.Header().Set("Access-Control-Allow-Methods", allowMethodsResponseHeader)
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", allowHeadersResponseHeader)
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", exposeHeadersResponseHeader)
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", allowCredentials)
|
w.Header().Set("Access-Control-Allow-Credentials", allowCredentials)
|
||||||
|
|
||||||
// Stop here if its a preflight OPTIONS request
|
// Stop here if its a preflight OPTIONS request
|
||||||
@ -69,13 +77,37 @@ func WithCORS(handler http.Handler, allowedOriginPatterns []string, allowedMetho
|
|||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
// Dispatch to the next handler
|
// Dispatch to the next handler
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isOriginAllowed returns true if the given origin header in the
|
||||||
|
// request is allowed CORS.
|
||||||
|
//
|
||||||
|
// From https://www.rfc-editor.org/rfc/rfc6454#page-13
|
||||||
|
//
|
||||||
|
// a) The origin header can contain host and/or port
|
||||||
|
// serialized-origin = scheme "://" host [ ":" port ]
|
||||||
|
//
|
||||||
|
// b) In some cases, a number of origins contribute to causing the user
|
||||||
|
// agents to issue an HTTP request. In those cases, the user agent MAY
|
||||||
|
// list all the origins in the Origin header field. For example, if the
|
||||||
|
// HTTP request was initially issued by one origin but then later
|
||||||
|
// redirected by another origin, the user agent MAY inform the server
|
||||||
|
// that two origins were involved in causing the user agent to issue the
|
||||||
|
// request
|
||||||
|
// origin-list = serialized-origin *( SP serialized-origin )
|
||||||
|
func isOriginAllowed(originHeader string, allowedOriginPatternsREs []*regexp.Regexp) bool {
|
||||||
|
for _, re := range allowedOriginPatternsREs {
|
||||||
|
if re.MatchString(originHeader) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func allowedOriginRegexps(allowedOrigins []string) []*regexp.Regexp {
|
func allowedOriginRegexps(allowedOrigins []string) []*regexp.Regexp {
|
||||||
res, err := compileRegexps(allowedOrigins)
|
res, err := compileRegexps(allowedOrigins)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
package filters
|
package filters
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -25,22 +26,60 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCORSAllowedOrigins(t *testing.T) {
|
func TestCORSAllowedOrigins(t *testing.T) {
|
||||||
table := []struct {
|
tests := []struct {
|
||||||
|
name string
|
||||||
allowedOrigins []string
|
allowedOrigins []string
|
||||||
origin string
|
origins []string
|
||||||
allowed bool
|
allowed bool
|
||||||
}{
|
}{
|
||||||
{[]string{}, "example.com", false},
|
{
|
||||||
{[]string{"example.com"}, "example.com", true},
|
name: "allowed origins list is empty",
|
||||||
{[]string{"example.com"}, "not-allowed.com", false},
|
allowedOrigins: []string{},
|
||||||
{[]string{"not-matching.com", "example.com"}, "example.com", true},
|
origins: []string{"example.com"},
|
||||||
{[]string{".*"}, "example.com", true},
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "origin request header not set",
|
||||||
|
allowedOrigins: []string{"example.com"},
|
||||||
|
origins: []string{""},
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allowed regexp is a match",
|
||||||
|
allowedOrigins: []string{"example.com"},
|
||||||
|
origins: []string{"http://example.com", "example.com"},
|
||||||
|
allowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allowed regexp is not a match",
|
||||||
|
allowedOrigins: []string{"example.com"},
|
||||||
|
origins: []string{"http://not-allowed.com", "not-allowed.com"},
|
||||||
|
allowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allowed list with multiple regex",
|
||||||
|
allowedOrigins: []string{"not-matching.com", "example.com"},
|
||||||
|
origins: []string{"http://example.com", "example.com"},
|
||||||
|
allowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard matching",
|
||||||
|
allowedOrigins: []string{".*"},
|
||||||
|
origins: []string{"http://example.com", "example.com"},
|
||||||
|
allowed: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, item := range table {
|
for _, test := range tests {
|
||||||
|
for _, origin := range test.origins {
|
||||||
|
name := fmt.Sprintf("%s/origin/%s", test.name, origin)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
var handlerInvoked int
|
||||||
handler := WithCORS(
|
handler := WithCORS(
|
||||||
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}),
|
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||||
item.allowedOrigins, nil, nil, nil, "true",
|
handlerInvoked++
|
||||||
|
}),
|
||||||
|
test.allowedOrigins, nil, nil, nil, "true",
|
||||||
)
|
)
|
||||||
var response *http.Response
|
var response *http.Response
|
||||||
func() {
|
func() {
|
||||||
@ -51,16 +90,20 @@ func TestCORSAllowedOrigins(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
request.Header.Set("Origin", item.origin)
|
request.Header.Set("Origin", origin)
|
||||||
client := http.Client{}
|
client := http.Client{}
|
||||||
response, err = client.Do(request)
|
response, err = client.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if item.allowed {
|
if handlerInvoked != 1 {
|
||||||
if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) {
|
t.Errorf("Expected the handler to be invoked once, but got: %d", handlerInvoked)
|
||||||
t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin"))
|
}
|
||||||
|
|
||||||
|
if test.allowed {
|
||||||
|
if !reflect.DeepEqual(origin, response.Header.Get("Access-Control-Allow-Origin")) {
|
||||||
|
t.Errorf("Expected %#v, Got %#v", origin, response.Header.Get("Access-Control-Allow-Origin"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if response.Header.Get("Access-Control-Allow-Credentials") == "" {
|
if response.Header.Get("Access-Control-Allow-Credentials") == "" {
|
||||||
@ -99,6 +142,8 @@ func TestCORSAllowedOrigins(t *testing.T) {
|
|||||||
t.Errorf("Expected Date in Access-Control-Expose-Headers header")
|
t.Errorf("Expected Date in Access-Control-Expose-Headers header")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user