From 8b4ca9c2a7913d2dc6c951cc8ba3a1a00c811e2a Mon Sep 17 00:00:00 2001 From: Jessica Forrester Date: Wed, 3 Sep 2014 14:33:52 -0400 Subject: [PATCH] Move CORS handler wrapping into cmd/apiserver and switch config flag to a list of allowed origins --- cmd/apiserver/apiserver.go | 33 ++++++++------ cmd/integration/integration.go | 2 +- hack/local-up-cluster.sh | 5 ++- pkg/apiserver/apiserver.go | 6 +-- pkg/apiserver/apiserver_test.go | 71 ++++++++++++++++++++++--------- pkg/apiserver/handlers.go | 19 +++++++-- pkg/apiserver/minionproxy_test.go | 2 +- pkg/apiserver/operation_test.go | 4 +- pkg/apiserver/redirect_test.go | 2 +- pkg/apiserver/watch_test.go | 6 +-- test/integration/client_test.go | 2 +- 11 files changed, 102 insertions(+), 50 deletions(-) diff --git a/cmd/apiserver/apiserver.go b/cmd/apiserver/apiserver.go index bebc3f91696..96fa3e44d20 100644 --- a/cmd/apiserver/apiserver.go +++ b/cmd/apiserver/apiserver.go @@ -36,22 +36,25 @@ import ( ) var ( - port = flag.Uint("port", 8080, "The port to listen on. Default 8080.") - address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1") - apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'") - enableCORS = flag.Bool("enable_cors", false, "If true, the basic CORS implementation will be enabled. [default false]") - cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.") - cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.") - minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs") - minionPort = flag.Uint("minion_port", 10250, "The port at which kubelet will be listening on the minions.") - healthCheckMinions = flag.Bool("health_check_minions", true, "If true, health check minions and filter unhealthy ones. [default true]") - minionCacheTTL = flag.Duration("minion_cache_ttl", 30*time.Second, "Duration of time to cache minion information. [default 30 seconds]") - etcdServerList, machineList util.StringList + port = flag.Uint("port", 8080, "The port to listen on. Default 8080.") + address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1") + apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'") + enableCORS = flag.Bool("enable_cors", false, "If true, the basic CORS implementation will be enabled. [default false]") + cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.") + cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.") + minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs") + minionPort = flag.Uint("minion_port", 10250, "The port at which kubelet will be listening on the minions.") + healthCheckMinions = flag.Bool("health_check_minions", true, "If true, health check minions and filter unhealthy ones. [default true]") + minionCacheTTL = flag.Duration("minion_cache_ttl", 30*time.Second, "Duration of time to cache minion information. [default 30 seconds]") + etcdServerList util.StringList + machineList util.StringList + corsAllowedOriginList util.StringList ) func init() { flag.Var(&etcdServerList, "etcd_servers", "List of etcd servers to watch (http://ip:port), comma separated") flag.Var(&machineList, "machines", "List of machines to schedule onto, comma separated.") + flag.Var(&corsAllowedOriginList, "cors_allowed_origins", "List of allowed origins for CORS, comma separated. An allowed origin can be a regular expression to support subdomain matching. If this list is empty CORS will not be enabled.") } func verifyMinionFlags() { @@ -133,9 +136,15 @@ func main() { }) storage, codec := m.API_v1beta1() + + handler := apiserver.Handle(storage, codec, *apiPrefix) + if len(corsAllowedOriginList) > 0 { + handler = apiserver.CORS(handler, corsAllowedOriginList, nil, nil, "true") + } + s := &http.Server{ Addr: net.JoinHostPort(*address, strconv.Itoa(int(*port))), - Handler: apiserver.Handle(storage, codec, *apiPrefix, *enableCORS), + Handler: handler, ReadTimeout: 5 * time.Minute, WriteTimeout: 5 * time.Minute, MaxHeaderBytes: 1 << 20, diff --git a/cmd/integration/integration.go b/cmd/integration/integration.go index 43c17dd7da9..14fe09fdd33 100644 --- a/cmd/integration/integration.go +++ b/cmd/integration/integration.go @@ -116,7 +116,7 @@ func startComponents(manifestURL string) (apiServerURL string) { PodInfoGetter: fakePodInfoGetter{}, }) storage, codec := m.API_v1beta1() - handler.delegate = apiserver.Handle(storage, codec, "/api/v1beta1", false) + handler.delegate = apiserver.Handle(storage, codec, "/api/v1beta1") // Scheduler scheduler.New((&factory.ConfigFactory{cl}).Create()).Run() diff --git a/hack/local-up-cluster.sh b/hack/local-up-cluster.sh index ac564515c94..875da775b67 100755 --- a/hack/local-up-cluster.sh +++ b/hack/local-up-cluster.sh @@ -39,7 +39,8 @@ set +e API_PORT=${API_PORT:-8080} API_HOST=${API_HOST:-127.0.0.1} -API_ENABLE_CORS=${API_ENABLE_CORS:-false} +# By default only allow CORS for requests on localhost +API_CORS_ALLOWED_ORIGINS=${API_CORS_ALLOWED_ORIGINS:-127.0.0.1:.*,localhost:.*} KUBELET_PORT=${KUBELET_PORT:-10250} GO_OUT=$(dirname $0)/../_output/go/bin @@ -50,7 +51,7 @@ APISERVER_LOG=/tmp/apiserver.log --port="${API_PORT}" \ --etcd_servers="http://127.0.0.1:4001" \ --machines="127.0.0.1" \ - --enable_cors="${API_ENABLE_CORS}" >"${APISERVER_LOG}" 2>&1 & + --cors_allowed_origins="${API_CORS_ALLOWED_ORIGINS}" >"${APISERVER_LOG}" 2>&1 & APISERVER_PID=$! CTLRMGR_LOG=/tmp/controller-manager.log diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 4d6b57d9d10..908774cd3ce 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -55,11 +55,7 @@ func Handle(storage map[string]RESTStorage, codec runtime.Codec, prefix string) mux := http.NewServeMux() group.InstallREST(mux, prefix) InstallSupport(mux) - handler := RecoverPanics(mux) - if enableCORS { - handler = CORS(handler, []string{".*"}, nil, nil, "true") - } - return &defaultAPIServer{handler, group} + return &defaultAPIServer{RecoverPanics(mux), group} } // APIGroup is a http.Handler that exposes multiple RESTStorage objects diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index e0f8806959f..6f970c2eccc 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -191,7 +191,7 @@ func TestNotFound(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": &SimpleRESTStorage{}, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} for k, v := range cases { @@ -212,7 +212,7 @@ func TestNotFound(t *testing.T) { } func TestVersion(t *testing.T) { - handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version", false) + handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -241,7 +241,7 @@ func TestSimpleList(t *testing.T) { storage := map[string]RESTStorage{} simpleStorage := SimpleRESTStorage{} storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple") @@ -260,7 +260,7 @@ func TestErrorList(t *testing.T) { errors: map[string]error{"list": fmt.Errorf("test Error")}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple") @@ -284,7 +284,7 @@ func TestNonEmptyList(t *testing.T) { }, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple") @@ -319,7 +319,7 @@ func TestGet(t *testing.T) { }, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple/id") @@ -340,7 +340,7 @@ func TestGetMissing(t *testing.T) { errors: map[string]error{"get": apierrs.NewNotFound("simple", "id")}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) resp, err := http.Get(server.URL + "/prefix/version/simple/id") @@ -358,7 +358,7 @@ func TestDelete(t *testing.T) { simpleStorage := SimpleRESTStorage{} ID := "id" storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -380,7 +380,7 @@ func TestDeleteMissing(t *testing.T) { errors: map[string]error{"delete": apierrs.NewNotFound("simple", ID)}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -400,7 +400,7 @@ func TestUpdate(t *testing.T) { simpleStorage := SimpleRESTStorage{} ID := "id" storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) item := &Simple{ @@ -430,7 +430,7 @@ func TestUpdateMissing(t *testing.T) { errors: map[string]error{"update": apierrs.NewNotFound("simple", ID)}, } storage["simple"] = &simpleStorage - handler := Handle(storage, codec, "/prefix/version", false) + handler := Handle(storage, codec, "/prefix/version") server := httptest.NewServer(handler) item := &Simple{ @@ -457,7 +457,7 @@ func TestCreate(t *testing.T) { simpleStorage := &SimpleRESTStorage{} handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) client := http.Client{} @@ -498,7 +498,7 @@ func TestCreateNotFound(t *testing.T) { // See https://github.com/GoogleCloudPlatform/kubernetes/pull/486#discussion_r15037092. errors: map[string]error{"create": apierrs.NewNotFound("simple", "id")}, }, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -540,7 +540,7 @@ func TestSyncCreate(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": &storage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -609,7 +609,7 @@ func TestAsyncDelayReturnsError(t *testing.T) { return nil, apierrs.NewAlreadyExists("foo", "bar") }, } - handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version", false) + handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version") handler.(*defaultAPIServer).group.handler.asyncOpWait = time.Millisecond / 2 server := httptest.NewServer(handler) @@ -627,7 +627,7 @@ func TestAsyncCreateError(t *testing.T) { return nil, apierrs.NewAlreadyExists("foo", "bar") }, } - handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version", false) + handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version") handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) @@ -721,7 +721,7 @@ func TestSyncCreateTimeout(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": &storage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) simple := &Simple{Name: "foo"} @@ -732,8 +732,8 @@ func TestSyncCreateTimeout(t *testing.T) { } } -func TestEnableCORS(t *testing.T) { - handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version", true) +func TestCORSAllowedOrigin(t *testing.T) { + handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []string{"example.com"}, nil, nil, "true") server := httptest.NewServer(handler) client := http.Client{} @@ -764,3 +764,36 @@ func TestEnableCORS(t *testing.T) { t.Errorf("Expected Access-Control-Allow-Methods header to be set") } } + +func TestCORSUnallowedOrigin(t *testing.T) { + handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []string{"example.com"}, nil, nil, "true") + server := httptest.NewServer(handler) + client := http.Client{} + + request, err := http.NewRequest("GET", server.URL+"/version", nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + request.Header.Set("Origin", "not-allowed.com") + + response, err := client.Do(request) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if response.Header.Get("Access-Control-Allow-Origin") != "" { + t.Errorf("Expected Access-Control-Allow-Origin header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Credentials") != "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") != "" { + t.Errorf("Expected Access-Control-Allow-Headers header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") != "" { + t.Errorf("Expected Access-Control-Allow-Methods header to not be set") + } +} diff --git a/pkg/apiserver/handlers.go b/pkg/apiserver/handlers.go index ac79f4d84d6..4554e15f25e 100644 --- a/pkg/apiserver/handlers.go +++ b/pkg/apiserver/handlers.go @@ -24,6 +24,7 @@ import ( "strings" "github.com/GoogleCloudPlatform/kubernetes/pkg/httplog" + "github.com/GoogleCloudPlatform/kubernetes/pkg/util" "github.com/golang/glog" ) @@ -58,13 +59,25 @@ func RecoverPanics(handler http.Handler) http.Handler { // For a more detailed implementation use https://github.com/martini-contrib/cors // or implement CORS at your proxy layer // Pass nil for allowedMethods and allowedHeaders to use the defaults -func CORS(handler http.Handler, allowedOriginPatterns []string, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler { +func CORS(handler http.Handler, allowedOriginPatterns util.StringList, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler { + // Compile the regular expressions once upfront + allowedOriginRegexps := []*regexp.Regexp{} + for _, allowedOrigin := range allowedOriginPatterns { + allowedOriginRegexp, err := regexp.Compile(allowedOrigin) + if err != nil { + glog.Fatalf("Invalid CORS allowed origin regexp: %v", err) + } + allowedOriginRegexps = append(allowedOriginRegexps, allowedOriginRegexp) + } + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { origin := req.Header.Get("Origin") if origin != "" { allowed := false - for _, pattern := range allowedOriginPatterns { - allowed, _ = regexp.MatchString(pattern, origin) + for _, pattern := range allowedOriginRegexps { + if allowed = pattern.MatchString(origin); allowed { + break + } } if allowed { w.Header().Set("Access-Control-Allow-Origin", origin) diff --git a/pkg/apiserver/minionproxy_test.go b/pkg/apiserver/minionproxy_test.go index 26c6820dff8..1ceeedf3a63 100644 --- a/pkg/apiserver/minionproxy_test.go +++ b/pkg/apiserver/minionproxy_test.go @@ -127,7 +127,7 @@ func TestApiServerMinionProxy(t *testing.T) { proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte(req.URL.Path)) })) - server := httptest.NewServer(Handle(nil, nil, "/prefix", false)) + server := httptest.NewServer(Handle(nil, nil, "/prefix")) proxy, _ := url.Parse(proxyServer.URL) resp, err := http.Get(fmt.Sprintf("%s/proxy/minion/%s%s", server.URL, proxy.Host, "/test")) if err != nil { diff --git a/pkg/apiserver/operation_test.go b/pkg/apiserver/operation_test.go index 203a415232b..856fe764389 100644 --- a/pkg/apiserver/operation_test.go +++ b/pkg/apiserver/operation_test.go @@ -107,7 +107,7 @@ func TestOperationsList(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) client := http.Client{} @@ -163,7 +163,7 @@ func TestOpGet(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") handler.(*defaultAPIServer).group.handler.asyncOpWait = 0 server := httptest.NewServer(handler) client := http.Client{} diff --git a/pkg/apiserver/redirect_test.go b/pkg/apiserver/redirect_test.go index 5730b75b95e..6425f920d8c 100644 --- a/pkg/apiserver/redirect_test.go +++ b/pkg/apiserver/redirect_test.go @@ -30,7 +30,7 @@ func TestRedirect(t *testing.T) { } handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) dontFollow := errors.New("don't follow") diff --git a/pkg/apiserver/watch_test.go b/pkg/apiserver/watch_test.go index ff5bbc864c5..70027c1d44a 100644 --- a/pkg/apiserver/watch_test.go +++ b/pkg/apiserver/watch_test.go @@ -44,7 +44,7 @@ func TestWatchWebsocket(t *testing.T) { _ = ResourceWatcher(simpleStorage) // Give compile error if this doesn't work. handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) dest, _ := url.Parse(server.URL) @@ -90,7 +90,7 @@ func TestWatchHTTP(t *testing.T) { simpleStorage := &SimpleRESTStorage{} handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) client := http.Client{} @@ -147,7 +147,7 @@ func TestWatchParamParsing(t *testing.T) { simpleStorage := &SimpleRESTStorage{} handler := Handle(map[string]RESTStorage{ "foo": simpleStorage, - }, codec, "/prefix/version", false) + }, codec, "/prefix/version") server := httptest.NewServer(handler) dest, _ := url.Parse(server.URL) diff --git a/test/integration/client_test.go b/test/integration/client_test.go index f8f3e356455..5d539d11e29 100644 --- a/test/integration/client_test.go +++ b/test/integration/client_test.go @@ -41,7 +41,7 @@ func TestClient(t *testing.T) { }) storage, codec := m.API_v1beta1() - s := httptest.NewServer(apiserver.Handle(storage, codec, "/api/v1beta1/", false)) + s := httptest.NewServer(apiserver.Handle(storage, codec, "/api/v1beta1/")) client := client.NewOrDie(s.URL, nil)