diff --git a/cmd/apiserver/apiserver.go b/cmd/apiserver/apiserver.go index 43553e88808..1becccec4eb 100644 --- a/cmd/apiserver/apiserver.go +++ b/cmd/apiserver/apiserver.go @@ -24,6 +24,7 @@ import ( "net/http" "os" "strconv" + "strings" "time" "github.com/GoogleCloudPlatform/kubernetes/pkg/apiserver" @@ -141,7 +142,7 @@ func main() { if len(corsAllowedOriginList) > 0 { allowedOriginRegexps, err := util.CompileRegexps(corsAllowedOriginList) if err != nil { - glog.Fatalf("Invalid CORS allowed origin: %v", err) + glog.Fatalf("Invalid CORS allowed origin, --cors_allowed_origins flag was set to %v - %v", strings.Join(corsAllowedOriginList, ","), err) } handler = apiserver.CORS(handler, allowedOriginRegexps, nil, nil, "true") } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 19e412bb3ae..59dceff11ec 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -25,7 +25,6 @@ import ( "net/http" "net/http/httptest" "reflect" - "regexp" "strings" "sync" "testing" @@ -35,6 +34,7 @@ import ( apierrs "github.com/GoogleCloudPlatform/kubernetes/pkg/api/errors" "github.com/GoogleCloudPlatform/kubernetes/pkg/labels" "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" + "github.com/GoogleCloudPlatform/kubernetes/pkg/util" "github.com/GoogleCloudPlatform/kubernetes/pkg/version" "github.com/GoogleCloudPlatform/kubernetes/pkg/watch" ) @@ -733,68 +733,72 @@ func TestSyncCreateTimeout(t *testing.T) { } } -func TestCORSAllowedOrigin(t *testing.T) { - handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []*regexp.Regexp{regexp.MustCompile("example.com")}, nil, nil, "true") - server := httptest.NewServer(handler) - client := http.Client{} - - request, err := http.NewRequest("GET", server.URL+"/version", nil) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - request.Header.Set("Origin", "example.com") - - response, err := client.Do(request) - if err != nil { - t.Errorf("unexpected error: %v", err) +func TestCORSAllowedOrigins(t *testing.T) { + table := []struct { + allowedOrigins util.StringList + origin string + allowed bool + }{ + {[]string{}, "example.com", false}, + {[]string{"example.com"}, "example.com", true}, + {[]string{"example.com"}, "not-allowed.com", false}, + {[]string{"not-matching.com", "example.com"}, "example.com", true}, + {[]string{".*"}, "example.com", true}, } - if !reflect.DeepEqual(response.Header.Get("Access-Control-Allow-Origin"), "example.com") { - t.Errorf("Expected %#v, Got %#v", response.Header.Get("Access-Control-Allow-Origin"), "example.com") - } + for _, item := range table { + allowedOriginRegexps, err := util.CompileRegexps(item.allowedOrigins) + if err != nil { + t.Errorf("unexpected error: %v", err) + } - if response.Header.Get("Access-Control-Allow-Credentials") == "" { - t.Errorf("Expected Access-Control-Allow-Credentials header to be set") - } + handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), allowedOriginRegexps, nil, nil, "true") + server := httptest.NewServer(handler) + client := http.Client{} - if response.Header.Get("Access-Control-Allow-Headers") == "" { - t.Errorf("Expected Access-Control-Allow-Headers header to be set") - } + request, err := http.NewRequest("GET", server.URL+"/version", nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + request.Header.Set("Origin", item.origin) - if response.Header.Get("Access-Control-Allow-Methods") == "" { - t.Errorf("Expected Access-Control-Allow-Methods header to be set") - } -} - -func TestCORSUnallowedOrigin(t *testing.T) { - handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []*regexp.Regexp{regexp.MustCompile("example.com")}, nil, nil, "true") - server := httptest.NewServer(handler) - client := http.Client{} - - request, err := http.NewRequest("GET", server.URL+"/version", nil) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - request.Header.Set("Origin", "not-allowed.com") - - response, err := client.Do(request) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - if response.Header.Get("Access-Control-Allow-Origin") != "" { - t.Errorf("Expected Access-Control-Allow-Origin header to not be set") - } - - if response.Header.Get("Access-Control-Allow-Credentials") != "" { - t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") - } - - if response.Header.Get("Access-Control-Allow-Headers") != "" { - t.Errorf("Expected Access-Control-Allow-Headers header to not be set") - } - - if response.Header.Get("Access-Control-Allow-Methods") != "" { - t.Errorf("Expected Access-Control-Allow-Methods header to not be set") + response, err := client.Do(request) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if item.allowed { + if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) { + t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin")) + } + + if response.Header.Get("Access-Control-Allow-Credentials") == "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") == "" { + t.Errorf("Expected Access-Control-Allow-Headers header to be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") == "" { + t.Errorf("Expected Access-Control-Allow-Methods header to be set") + } + } else { + if response.Header.Get("Access-Control-Allow-Origin") != "" { + t.Errorf("Expected Access-Control-Allow-Origin header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Credentials") != "" { + t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Headers") != "" { + t.Errorf("Expected Access-Control-Allow-Headers header to not be set") + } + + if response.Header.Get("Access-Control-Allow-Methods") != "" { + t.Errorf("Expected Access-Control-Allow-Methods header to not be set") + } + } } } diff --git a/pkg/apiserver/redirect.go b/pkg/apiserver/redirect.go index 57aff02d3d1..d5175865404 100644 --- a/pkg/apiserver/redirect.go +++ b/pkg/apiserver/redirect.go @@ -38,14 +38,14 @@ func (r *RedirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { id := parts[1] storage, ok := r.storage[resourceName] if !ok { - httplog.LogOf(w).Addf("'%v' has no storage object", resourceName) + httplog.LogOf(req, w).Addf("'%v' has no storage object", resourceName) notFound(w, req) return } redirector, ok := storage.(Redirector) if !ok { - httplog.LogOf(w).Addf("'%v' is not a redirector", resourceName) + httplog.LogOf(req, w).Addf("'%v' is not a redirector", resourceName) notFound(w, req) return } diff --git a/pkg/apiserver/resthandler.go b/pkg/apiserver/resthandler.go index 8f5657a5d8f..4b6977b9e8d 100644 --- a/pkg/apiserver/resthandler.go +++ b/pkg/apiserver/resthandler.go @@ -42,7 +42,7 @@ func (h *RESTHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } storage := h.storage[parts[0]] if storage == nil { - httplog.FindOrCreateLogOf(req, &w).Addf("'%v' has no storage object", parts[0]) + httplog.LogOf(req, w).Addf("'%v' has no storage object", parts[0]) notFound(w, req) return } @@ -114,7 +114,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt return } op := h.createOperation(out, sync, timeout) - h.finishReq(op, w) + h.finishReq(op, req, w) case "DELETE": if len(parts) != 2 { @@ -127,7 +127,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt return } op := h.createOperation(out, sync, timeout) - h.finishReq(op, w) + h.finishReq(op, req, w) case "PUT": if len(parts) != 2 { @@ -151,7 +151,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt return } op := h.createOperation(out, sync, timeout) - h.finishReq(op, w) + h.finishReq(op, req, w) default: notFound(w, req) @@ -171,7 +171,7 @@ func (h *RESTHandler) createOperation(out <-chan runtime.Object, sync bool, time // finishReq finishes up a request, waiting until the operation finishes or, after a timeout, creating an // Operation to receive the result and returning its ID down the writer. -func (h *RESTHandler) finishReq(op *Operation, w http.ResponseWriter) { +func (h *RESTHandler) finishReq(op *Operation, req *http.Request, w http.ResponseWriter) { obj, complete := op.StatusOrResult() if complete { status := http.StatusOK diff --git a/pkg/apiserver/watch.go b/pkg/apiserver/watch.go index 60ff30c489e..771682087da 100644 --- a/pkg/apiserver/watch.go +++ b/pkg/apiserver/watch.go @@ -127,7 +127,7 @@ func (w *WatchServer) HandleWS(ws *websocket.Conn) { // ServeHTTP serves a series of JSON encoded events via straight HTTP with // Transfer-Encoding: chunked. func (self *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - loggedW := httplog.FindOrCreateLogOf(req, &w) + loggedW := httplog.LogOf(req, w) w = httplog.Unlogged(w) cn, ok := w.(http.CloseNotifier) diff --git a/pkg/httplog/log.go b/pkg/httplog/log.go index d9d887b9dee..912625b7eb1 100644 --- a/pkg/httplog/log.go +++ b/pkg/httplog/log.go @@ -86,22 +86,17 @@ func NewLogged(req *http.Request, w *http.ResponseWriter) *respLogger { return rl } -// LogOf returns the logger hiding in w. Panics if there isn't such a logger, -// because NewLogged() must have been previously called for the log to work. -func LogOf(w http.ResponseWriter) *respLogger { +// LogOf returns the logger hiding in w. If there is not an existing logger +// then one will be created because NewLogged() must have been previously +// called for the log to work. +func LogOf(req *http.Request, w http.ResponseWriter) *respLogger { + if _, exists := w.(*respLogger); !exists { + NewLogged(req, &w) + } if rl, ok := w.(*respLogger); ok { return rl } - panic("Logger not installed yet!") -} - -// Returns the existing logger hiding in w. If there is not an existing logger -// then one will be created. -func FindOrCreateLogOf(req *http.Request, w *http.ResponseWriter) *respLogger { - if _, exists := (*w).(*respLogger); !exists { - NewLogged(req, w) - } - return LogOf(*w) + panic("Unable to find or create the logger!") } // Unlogged returns the original ResponseWriter, or w if it is not our inserted logger. diff --git a/pkg/httplog/log_test.go b/pkg/httplog/log_test.go index e287c8b11dd..0ae8f53cde6 100644 --- a/pkg/httplog/log_test.go +++ b/pkg/httplog/log_test.go @@ -91,19 +91,12 @@ func TestLogOf(t *testing.T) { t.Errorf("Unexpected error: %v", err) } handler := func(w http.ResponseWriter, r *http.Request) { - var want *respLogger if makeLogger { - want = NewLogged(req, &w) - } else { - defer func() { - if r := recover(); r == nil { - t.Errorf("Expected LogOf to panic") - } - }() + NewLogged(req, &w) } - got := LogOf(w) - if want != got { - t.Errorf("Expected %v, got %v", want, got) + got := reflect.TypeOf(*LogOf(r, w)).String() + if got != "httplog.respLogger" { + t.Errorf("Expected %v, got %v", "httplog.respLogger", got) } } w := httptest.NewRecorder()