From e3170d03723ed8b4106df3231f5ac8aa2247eb66 Mon Sep 17 00:00:00 2001 From: Fabiano Franz Date: Tue, 1 Nov 2016 13:07:57 -0200 Subject: [PATCH] Allow PATCH in an API CORS setup --- pkg/genericapiserver/filters/cors.go | 2 +- pkg/genericapiserver/filters/cors_test.go | 59 +++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/pkg/genericapiserver/filters/cors.go b/pkg/genericapiserver/filters/cors.go index c463047879e..4d6ee35820c 100644 --- a/pkg/genericapiserver/filters/cors.go +++ b/pkg/genericapiserver/filters/cors.go @@ -53,7 +53,7 @@ func WithCORS(handler http.Handler, allowedOriginPatterns []string, allowedMetho w.Header().Set("Access-Control-Allow-Origin", origin) // Set defaults for methods and headers if nothing was passed if allowedMethods == nil { - allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE"} + allowedMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"} } if allowedHeaders == nil { allowedHeaders = []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "X-Requested-With", "If-Modified-Since"} diff --git a/pkg/genericapiserver/filters/cors_test.go b/pkg/genericapiserver/filters/cors_test.go index 817c01a6f2c..bbd0b75a0e5 100644 --- a/pkg/genericapiserver/filters/cors_test.go +++ b/pkg/genericapiserver/filters/cors_test.go @@ -20,6 +20,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "strings" "testing" ) @@ -72,6 +73,7 @@ func TestCORSAllowedOrigins(t *testing.T) { if response.Header.Get("Access-Control-Allow-Methods") == "" { t.Errorf("Expected Access-Control-Allow-Methods header to be set") } + if response.Header.Get("Access-Control-Expose-Headers") != "Date" { t.Errorf("Expected Date in Access-Control-Expose-Headers header") } @@ -91,9 +93,66 @@ func TestCORSAllowedOrigins(t *testing.T) { if response.Header.Get("Access-Control-Allow-Methods") != "" { t.Errorf("Expected Access-Control-Allow-Methods header to not be set") } + if response.Header.Get("Access-Control-Expose-Headers") == "Date" { t.Errorf("Expected Date in Access-Control-Expose-Headers header") } } } } + +func TestCORSAllowedMethods(t *testing.T) { + tests := []struct { + allowedMethods []string + method string + allowed bool + }{ + {nil, "POST", true}, + {nil, "GET", true}, + {nil, "OPTIONS", true}, + {nil, "PUT", true}, + {nil, "DELETE", true}, + {nil, "PATCH", true}, + {[]string{"GET", "POST"}, "PATCH", false}, + } + + allowsMethod := func(res *http.Response, method string) bool { + allowedMethods := strings.Split(res.Header.Get("Access-Control-Allow-Methods"), ",") + for _, allowedMethod := range allowedMethods { + if strings.TrimSpace(allowedMethod) == method { + return true + } + } + return false + } + + for _, test := range tests { + handler := WithCORS( + http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}), + []string{".*"}, test.allowedMethods, nil, nil, "true", + ) + server := httptest.NewServer(handler) + defer server.Close() + client := http.Client{} + + request, err := http.NewRequest(test.method, server.URL+"/version", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + request.Header.Set("Origin", "allowed.com") + + response, err := client.Do(request) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + methodAllowed := allowsMethod(response, test.method) + switch { + case test.allowed && !methodAllowed: + t.Errorf("Expected %v to be allowed, Got only %#v", test.method, response.Header.Get("Access-Control-Allow-Methods")) + case !test.allowed && methodAllowed: + t.Errorf("Unexpected allowed method %v, Expected only %#v", test.method, response.Header.Get("Access-Control-Allow-Methods")) + } + } + +}