diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/cors.go b/staging/src/k8s.io/apiserver/pkg/server/filters/cors.go index 29c46e4c793..ca078cb0520 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/filters/cors.go +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/cors.go @@ -38,44 +38,76 @@ func WithCORS(handler http.Handler, allowedOriginPatterns []string, allowedMetho return handler } allowedOriginPatternsREs := allowedOriginRegexps(allowedOriginPatterns) + + // Set defaults for methods and headers if nothing was passed + if allowedMethods == nil { + allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"} + } + allowMethodsResponseHeader := strings.Join(allowedMethods, ", ") + + if allowedHeaders == nil { + allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"} + } + allowHeadersResponseHeader := strings.Join(allowedHeaders, ", ") + + if exposedHeaders == nil { + exposedHeaders = []string{"Date"} + } + exposeHeadersResponseHeader := strings.Join(exposedHeaders, ", ") + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { origin := req.Header.Get("Origin") - if origin != "" { - allowed := false - for _, re := range allowedOriginPatternsREs { - if allowed = re.MatchString(origin); allowed { - break - } - } - if allowed { - w.Header().Set("Access-Control-Allow-Origin", origin) - // Set defaults for methods and headers if nothing was passed - if allowedMethods == nil { - allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"} - } - if allowedHeaders == nil { - allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"} - } - if exposedHeaders == nil { - exposedHeaders = []string{"Date"} - } - w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", ")) - w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", ")) - w.Header().Set("Access-Control-Expose-Headers", strings.Join(exposedHeaders, ", ")) - w.Header().Set("Access-Control-Allow-Credentials", allowCredentials) - - // Stop here if its a preflight OPTIONS request - if req.Method == "OPTIONS" { - w.WriteHeader(http.StatusNoContent) - return - } - } + if origin == "" { + handler.ServeHTTP(w, req) + return } + if !isOriginAllowed(origin, allowedOriginPatternsREs) { + handler.ServeHTTP(w, req) + return + } + + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", allowMethodsResponseHeader) + w.Header().Set("Access-Control-Allow-Headers", allowHeadersResponseHeader) + w.Header().Set("Access-Control-Expose-Headers", exposeHeadersResponseHeader) + w.Header().Set("Access-Control-Allow-Credentials", allowCredentials) + + // Stop here if its a preflight OPTIONS request + if req.Method == "OPTIONS" { + w.WriteHeader(http.StatusNoContent) + return + } + // Dispatch to the next handler handler.ServeHTTP(w, req) }) } +// isOriginAllowed returns true if the given origin header in the +// request is allowed CORS. +// +// From https://www.rfc-editor.org/rfc/rfc6454#page-13 +// +// a) The origin header can contain host and/or port +// serialized-origin = scheme "://" host [ ":" port ] +// +// b) In some cases, a number of origins contribute to causing the user +// agents to issue an HTTP request. In those cases, the user agent MAY +// list all the origins in the Origin header field. For example, if the +// HTTP request was initially issued by one origin but then later +// redirected by another origin, the user agent MAY inform the server +// that two origins were involved in causing the user agent to issue the +// request +// origin-list = serialized-origin *( SP serialized-origin ) +func isOriginAllowed(originHeader string, allowedOriginPatternsREs []*regexp.Regexp) bool { + for _, re := range allowedOriginPatternsREs { + if re.MatchString(originHeader) { + return true + } + } + return false +} + func allowedOriginRegexps(allowedOrigins []string) []*regexp.Regexp { res, err := compileRegexps(allowedOrigins) if err != nil { diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/cors_test.go b/staging/src/k8s.io/apiserver/pkg/server/filters/cors_test.go index 8b39c9497d0..afab083a883 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/filters/cors_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/cors_test.go @@ -17,6 +17,7 @@ limitations under the License. package filters import ( + "fmt" "net/http" "net/http/httptest" "reflect" @@ -25,79 +26,123 @@ import ( ) func TestCORSAllowedOrigins(t *testing.T) { - table := []struct { + tests := []struct { + name string allowedOrigins []string - origin string + origins []string allowed bool }{ - {[]string{}, "example.com", false}, - {[]string{"example.com"}, "example.com", true}, - {[]string{"example.com"}, "not-allowed.com", false}, - {[]string{"not-matching.com", "example.com"}, "example.com", true}, - {[]string{".*"}, "example.com", true}, + { + name: "allowed origins list is empty", + allowedOrigins: []string{}, + origins: []string{"example.com"}, + allowed: false, + }, + { + name: "origin request header not set", + allowedOrigins: []string{"example.com"}, + origins: []string{""}, + allowed: false, + }, + { + name: "allowed regexp is a match", + allowedOrigins: []string{"example.com"}, + origins: []string{"http://example.com", "example.com"}, + allowed: true, + }, + { + name: "allowed regexp is not a match", + allowedOrigins: []string{"example.com"}, + origins: []string{"http://not-allowed.com", "not-allowed.com"}, + allowed: false, + }, + { + name: "allowed list with multiple regex", + allowedOrigins: []string{"not-matching.com", "example.com"}, + origins: []string{"http://example.com", "example.com"}, + allowed: true, + }, + { + name: "wildcard matching", + allowedOrigins: []string{".*"}, + origins: []string{"http://example.com", "example.com"}, + allowed: true, + }, } - for _, item := range table { - handler := WithCORS( - http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}), - item.allowedOrigins, nil, nil, nil, "true", - ) - var response *http.Response - func() { - server := httptest.NewServer(handler) - defer server.Close() + for _, test := range tests { + for _, origin := range test.origins { + name := fmt.Sprintf("%s/origin/%s", test.name, origin) + t.Run(name, func(t *testing.T) { + var handlerInvoked int + handler := WithCORS( + http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + handlerInvoked++ + }), + test.allowedOrigins, nil, nil, nil, "true", + ) + var response *http.Response + func() { + server := httptest.NewServer(handler) + defer server.Close() - request, err := http.NewRequest("GET", server.URL+"/version", nil) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - request.Header.Set("Origin", item.origin) - client := http.Client{} - response, err = client.Do(request) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - }() - if item.allowed { - if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) { - t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin")) - } + request, err := http.NewRequest("GET", server.URL+"/version", nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + request.Header.Set("Origin", origin) + client := http.Client{} + response, err = client.Do(request) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }() + if handlerInvoked != 1 { + t.Errorf("Expected the handler to be invoked once, but got: %d", handlerInvoked) + } - if response.Header.Get("Access-Control-Allow-Credentials") == "" { - t.Errorf("Expected Access-Control-Allow-Credentials header to be set") - } + if test.allowed { + if !reflect.DeepEqual(origin, response.Header.Get("Access-Control-Allow-Origin")) { + t.Errorf("Expected %#v, Got %#v", origin, response.Header.Get("Access-Control-Allow-Origin")) + } - if response.Header.Get("Access-Control-Allow-Headers") == "" { - t.Errorf("Expected Access-Control-Allow-Headers header to be set") - } + if response.Header.Get("Access-Control-Allow-Credentials") == "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to be set") + } - if response.Header.Get("Access-Control-Allow-Methods") == "" { - t.Errorf("Expected Access-Control-Allow-Methods header to be set") - } + if response.Header.Get("Access-Control-Allow-Headers") == "" { + t.Errorf("Expected Access-Control-Allow-Headers header to be set") + } - if response.Header.Get("Access-Control-Expose-Headers") != "Date" { - t.Errorf("Expected Date in Access-Control-Expose-Headers header") - } - } else { - if response.Header.Get("Access-Control-Allow-Origin") != "" { - t.Errorf("Expected Access-Control-Allow-Origin header to not be set") - } + if response.Header.Get("Access-Control-Allow-Methods") == "" { + t.Errorf("Expected Access-Control-Allow-Methods header to be set") + } - if response.Header.Get("Access-Control-Allow-Credentials") != "" { - t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") - } + if response.Header.Get("Access-Control-Expose-Headers") != "Date" { + t.Errorf("Expected Date in Access-Control-Expose-Headers header") + } + } else { + if response.Header.Get("Access-Control-Allow-Origin") != "" { + t.Errorf("Expected Access-Control-Allow-Origin header to not be set") + } - if response.Header.Get("Access-Control-Allow-Headers") != "" { - t.Errorf("Expected Access-Control-Allow-Headers header to not be set") - } + if response.Header.Get("Access-Control-Allow-Credentials") != "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") + } - if response.Header.Get("Access-Control-Allow-Methods") != "" { - t.Errorf("Expected Access-Control-Allow-Methods header to not be set") - } + if response.Header.Get("Access-Control-Allow-Headers") != "" { + t.Errorf("Expected Access-Control-Allow-Headers header to not be set") + } - if response.Header.Get("Access-Control-Expose-Headers") == "Date" { - t.Errorf("Expected Date in Access-Control-Expose-Headers header") - } + if response.Header.Get("Access-Control-Allow-Methods") != "" { + t.Errorf("Expected Access-Control-Allow-Methods header to not be set") + } + + if response.Header.Get("Access-Control-Expose-Headers") == "Date" { + t.Errorf("Expected Date in Access-Control-Expose-Headers header") + } + } + }) } } }