Move CORS handler wrapping into cmd/apiserver and switch config flag to a list of allowed origins

This commit is contained in:
Jessica Forrester 2014-09-03 14:33:52 -04:00
parent 8723eece49
commit 8b4ca9c2a7
11 changed files with 102 additions and 50 deletions

View File

@ -36,22 +36,25 @@ import (
)
var (
port = flag.Uint("port", 8080, "The port to listen on. Default 8080.")
address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1")
apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'")
enableCORS = flag.Bool("enable_cors", false, "If true, the basic CORS implementation will be enabled. [default false]")
cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.")
cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.")
minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs")
minionPort = flag.Uint("minion_port", 10250, "The port at which kubelet will be listening on the minions.")
healthCheckMinions = flag.Bool("health_check_minions", true, "If true, health check minions and filter unhealthy ones. [default true]")
minionCacheTTL = flag.Duration("minion_cache_ttl", 30*time.Second, "Duration of time to cache minion information. [default 30 seconds]")
etcdServerList, machineList util.StringList
port = flag.Uint("port", 8080, "The port to listen on. Default 8080.")
address = flag.String("address", "127.0.0.1", "The address on the local server to listen to. Default 127.0.0.1")
apiPrefix = flag.String("api_prefix", "/api/v1beta1", "The prefix for API requests on the server. Default '/api/v1beta1'")
enableCORS = flag.Bool("enable_cors", false, "If true, the basic CORS implementation will be enabled. [default false]")
cloudProvider = flag.String("cloud_provider", "", "The provider for cloud services. Empty string for no provider.")
cloudConfigFile = flag.String("cloud_config", "", "The path to the cloud provider configuration file. Empty string for no configuration file.")
minionRegexp = flag.String("minion_regexp", "", "If non empty, and -cloud_provider is specified, a regular expression for matching minion VMs")
minionPort = flag.Uint("minion_port", 10250, "The port at which kubelet will be listening on the minions.")
healthCheckMinions = flag.Bool("health_check_minions", true, "If true, health check minions and filter unhealthy ones. [default true]")
minionCacheTTL = flag.Duration("minion_cache_ttl", 30*time.Second, "Duration of time to cache minion information. [default 30 seconds]")
etcdServerList util.StringList
machineList util.StringList
corsAllowedOriginList util.StringList
)
func init() {
flag.Var(&etcdServerList, "etcd_servers", "List of etcd servers to watch (http://ip:port), comma separated")
flag.Var(&machineList, "machines", "List of machines to schedule onto, comma separated.")
flag.Var(&corsAllowedOriginList, "cors_allowed_origins", "List of allowed origins for CORS, comma separated. An allowed origin can be a regular expression to support subdomain matching. If this list is empty CORS will not be enabled.")
}
func verifyMinionFlags() {
@ -133,9 +136,15 @@ func main() {
})
storage, codec := m.API_v1beta1()
handler := apiserver.Handle(storage, codec, *apiPrefix)
if len(corsAllowedOriginList) > 0 {
handler = apiserver.CORS(handler, corsAllowedOriginList, nil, nil, "true")
}
s := &http.Server{
Addr: net.JoinHostPort(*address, strconv.Itoa(int(*port))),
Handler: apiserver.Handle(storage, codec, *apiPrefix, *enableCORS),
Handler: handler,
ReadTimeout: 5 * time.Minute,
WriteTimeout: 5 * time.Minute,
MaxHeaderBytes: 1 << 20,

View File

@ -116,7 +116,7 @@ func startComponents(manifestURL string) (apiServerURL string) {
PodInfoGetter: fakePodInfoGetter{},
})
storage, codec := m.API_v1beta1()
handler.delegate = apiserver.Handle(storage, codec, "/api/v1beta1", false)
handler.delegate = apiserver.Handle(storage, codec, "/api/v1beta1")
// Scheduler
scheduler.New((&factory.ConfigFactory{cl}).Create()).Run()

View File

@ -39,7 +39,8 @@ set +e
API_PORT=${API_PORT:-8080}
API_HOST=${API_HOST:-127.0.0.1}
API_ENABLE_CORS=${API_ENABLE_CORS:-false}
# By default only allow CORS for requests on localhost
API_CORS_ALLOWED_ORIGINS=${API_CORS_ALLOWED_ORIGINS:-127.0.0.1:.*,localhost:.*}
KUBELET_PORT=${KUBELET_PORT:-10250}
GO_OUT=$(dirname $0)/../_output/go/bin
@ -50,7 +51,7 @@ APISERVER_LOG=/tmp/apiserver.log
--port="${API_PORT}" \
--etcd_servers="http://127.0.0.1:4001" \
--machines="127.0.0.1" \
--enable_cors="${API_ENABLE_CORS}" >"${APISERVER_LOG}" 2>&1 &
--cors_allowed_origins="${API_CORS_ALLOWED_ORIGINS}" >"${APISERVER_LOG}" 2>&1 &
APISERVER_PID=$!
CTLRMGR_LOG=/tmp/controller-manager.log

View File

@ -55,11 +55,7 @@ func Handle(storage map[string]RESTStorage, codec runtime.Codec, prefix string)
mux := http.NewServeMux()
group.InstallREST(mux, prefix)
InstallSupport(mux)
handler := RecoverPanics(mux)
if enableCORS {
handler = CORS(handler, []string{".*"}, nil, nil, "true")
}
return &defaultAPIServer{handler, group}
return &defaultAPIServer{RecoverPanics(mux), group}
}
// APIGroup is a http.Handler that exposes multiple RESTStorage objects

View File

@ -191,7 +191,7 @@ func TestNotFound(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": &SimpleRESTStorage{},
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
server := httptest.NewServer(handler)
client := http.Client{}
for k, v := range cases {
@ -212,7 +212,7 @@ func TestNotFound(t *testing.T) {
}
func TestVersion(t *testing.T) {
handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version", false)
handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version")
server := httptest.NewServer(handler)
client := http.Client{}
@ -241,7 +241,7 @@ func TestSimpleList(t *testing.T) {
storage := map[string]RESTStorage{}
simpleStorage := SimpleRESTStorage{}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version", false)
handler := Handle(storage, codec, "/prefix/version")
server := httptest.NewServer(handler)
resp, err := http.Get(server.URL + "/prefix/version/simple")
@ -260,7 +260,7 @@ func TestErrorList(t *testing.T) {
errors: map[string]error{"list": fmt.Errorf("test Error")},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version", false)
handler := Handle(storage, codec, "/prefix/version")
server := httptest.NewServer(handler)
resp, err := http.Get(server.URL + "/prefix/version/simple")
@ -284,7 +284,7 @@ func TestNonEmptyList(t *testing.T) {
},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version", false)
handler := Handle(storage, codec, "/prefix/version")
server := httptest.NewServer(handler)
resp, err := http.Get(server.URL + "/prefix/version/simple")
@ -319,7 +319,7 @@ func TestGet(t *testing.T) {
},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version", false)
handler := Handle(storage, codec, "/prefix/version")
server := httptest.NewServer(handler)
resp, err := http.Get(server.URL + "/prefix/version/simple/id")
@ -340,7 +340,7 @@ func TestGetMissing(t *testing.T) {
errors: map[string]error{"get": apierrs.NewNotFound("simple", "id")},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version", false)
handler := Handle(storage, codec, "/prefix/version")
server := httptest.NewServer(handler)
resp, err := http.Get(server.URL + "/prefix/version/simple/id")
@ -358,7 +358,7 @@ func TestDelete(t *testing.T) {
simpleStorage := SimpleRESTStorage{}
ID := "id"
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version", false)
handler := Handle(storage, codec, "/prefix/version")
server := httptest.NewServer(handler)
client := http.Client{}
@ -380,7 +380,7 @@ func TestDeleteMissing(t *testing.T) {
errors: map[string]error{"delete": apierrs.NewNotFound("simple", ID)},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version", false)
handler := Handle(storage, codec, "/prefix/version")
server := httptest.NewServer(handler)
client := http.Client{}
@ -400,7 +400,7 @@ func TestUpdate(t *testing.T) {
simpleStorage := SimpleRESTStorage{}
ID := "id"
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version", false)
handler := Handle(storage, codec, "/prefix/version")
server := httptest.NewServer(handler)
item := &Simple{
@ -430,7 +430,7 @@ func TestUpdateMissing(t *testing.T) {
errors: map[string]error{"update": apierrs.NewNotFound("simple", ID)},
}
storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix/version", false)
handler := Handle(storage, codec, "/prefix/version")
server := httptest.NewServer(handler)
item := &Simple{
@ -457,7 +457,7 @@ func TestCreate(t *testing.T) {
simpleStorage := &SimpleRESTStorage{}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
handler.(*defaultAPIServer).group.handler.asyncOpWait = 0
server := httptest.NewServer(handler)
client := http.Client{}
@ -498,7 +498,7 @@ func TestCreateNotFound(t *testing.T) {
// See https://github.com/GoogleCloudPlatform/kubernetes/pull/486#discussion_r15037092.
errors: map[string]error{"create": apierrs.NewNotFound("simple", "id")},
},
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
server := httptest.NewServer(handler)
client := http.Client{}
@ -540,7 +540,7 @@ func TestSyncCreate(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": &storage,
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
server := httptest.NewServer(handler)
client := http.Client{}
@ -609,7 +609,7 @@ func TestAsyncDelayReturnsError(t *testing.T) {
return nil, apierrs.NewAlreadyExists("foo", "bar")
},
}
handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version", false)
handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version")
handler.(*defaultAPIServer).group.handler.asyncOpWait = time.Millisecond / 2
server := httptest.NewServer(handler)
@ -627,7 +627,7 @@ func TestAsyncCreateError(t *testing.T) {
return nil, apierrs.NewAlreadyExists("foo", "bar")
},
}
handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version", false)
handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix/version")
handler.(*defaultAPIServer).group.handler.asyncOpWait = 0
server := httptest.NewServer(handler)
@ -721,7 +721,7 @@ func TestSyncCreateTimeout(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": &storage,
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
server := httptest.NewServer(handler)
simple := &Simple{Name: "foo"}
@ -732,8 +732,8 @@ func TestSyncCreateTimeout(t *testing.T) {
}
}
func TestEnableCORS(t *testing.T) {
handler := Handle(map[string]RESTStorage{}, codec, "/prefix/version", true)
func TestCORSAllowedOrigin(t *testing.T) {
handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []string{"example.com"}, nil, nil, "true")
server := httptest.NewServer(handler)
client := http.Client{}
@ -764,3 +764,36 @@ func TestEnableCORS(t *testing.T) {
t.Errorf("Expected Access-Control-Allow-Methods header to be set")
}
}
func TestCORSUnallowedOrigin(t *testing.T) {
handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []string{"example.com"}, nil, nil, "true")
server := httptest.NewServer(handler)
client := http.Client{}
request, err := http.NewRequest("GET", server.URL+"/version", nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
request.Header.Set("Origin", "not-allowed.com")
response, err := client.Do(request)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if response.Header.Get("Access-Control-Allow-Origin") != "" {
t.Errorf("Expected Access-Control-Allow-Origin header to not be set")
}
if response.Header.Get("Access-Control-Allow-Credentials") != "" {
t.Errorf("Expected Access-Control-Allow-Credentials header to not be set")
}
if response.Header.Get("Access-Control-Allow-Headers") != "" {
t.Errorf("Expected Access-Control-Allow-Headers header to not be set")
}
if response.Header.Get("Access-Control-Allow-Methods") != "" {
t.Errorf("Expected Access-Control-Allow-Methods header to not be set")
}
}

View File

@ -24,6 +24,7 @@ import (
"strings"
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/golang/glog"
)
@ -58,13 +59,25 @@ func RecoverPanics(handler http.Handler) http.Handler {
// For a more detailed implementation use https://github.com/martini-contrib/cors
// or implement CORS at your proxy layer
// Pass nil for allowedMethods and allowedHeaders to use the defaults
func CORS(handler http.Handler, allowedOriginPatterns []string, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler {
func CORS(handler http.Handler, allowedOriginPatterns util.StringList, allowedMethods []string, allowedHeaders []string, allowCredentials string) http.Handler {
// Compile the regular expressions once upfront
allowedOriginRegexps := []*regexp.Regexp{}
for _, allowedOrigin := range allowedOriginPatterns {
allowedOriginRegexp, err := regexp.Compile(allowedOrigin)
if err != nil {
glog.Fatalf("Invalid CORS allowed origin regexp: %v", err)
}
allowedOriginRegexps = append(allowedOriginRegexps, allowedOriginRegexp)
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
origin := req.Header.Get("Origin")
if origin != "" {
allowed := false
for _, pattern := range allowedOriginPatterns {
allowed, _ = regexp.MatchString(pattern, origin)
for _, pattern := range allowedOriginRegexps {
if allowed = pattern.MatchString(origin); allowed {
break
}
}
if allowed {
w.Header().Set("Access-Control-Allow-Origin", origin)

View File

@ -127,7 +127,7 @@ func TestApiServerMinionProxy(t *testing.T) {
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.URL.Path))
}))
server := httptest.NewServer(Handle(nil, nil, "/prefix", false))
server := httptest.NewServer(Handle(nil, nil, "/prefix"))
proxy, _ := url.Parse(proxyServer.URL)
resp, err := http.Get(fmt.Sprintf("%s/proxy/minion/%s%s", server.URL, proxy.Host, "/test"))
if err != nil {

View File

@ -107,7 +107,7 @@ func TestOperationsList(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
handler.(*defaultAPIServer).group.handler.asyncOpWait = 0
server := httptest.NewServer(handler)
client := http.Client{}
@ -163,7 +163,7 @@ func TestOpGet(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
handler.(*defaultAPIServer).group.handler.asyncOpWait = 0
server := httptest.NewServer(handler)
client := http.Client{}

View File

@ -30,7 +30,7 @@ func TestRedirect(t *testing.T) {
}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
server := httptest.NewServer(handler)
dontFollow := errors.New("don't follow")

View File

@ -44,7 +44,7 @@ func TestWatchWebsocket(t *testing.T) {
_ = ResourceWatcher(simpleStorage) // Give compile error if this doesn't work.
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
server := httptest.NewServer(handler)
dest, _ := url.Parse(server.URL)
@ -90,7 +90,7 @@ func TestWatchHTTP(t *testing.T) {
simpleStorage := &SimpleRESTStorage{}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
server := httptest.NewServer(handler)
client := http.Client{}
@ -147,7 +147,7 @@ func TestWatchParamParsing(t *testing.T) {
simpleStorage := &SimpleRESTStorage{}
handler := Handle(map[string]RESTStorage{
"foo": simpleStorage,
}, codec, "/prefix/version", false)
}, codec, "/prefix/version")
server := httptest.NewServer(handler)
dest, _ := url.Parse(server.URL)

View File

@ -41,7 +41,7 @@ func TestClient(t *testing.T) {
})
storage, codec := m.API_v1beta1()
s := httptest.NewServer(apiserver.Handle(storage, codec, "/api/v1beta1/", false))
s := httptest.NewServer(apiserver.Handle(storage, codec, "/api/v1beta1/"))
client := client.NewOrDie(s.URL, nil)