Switch LogOf from panicking when logger is missing to creating logger with the defaults.

Update CORS tests to a table-based test and cover more cases.
This commit is contained in:
Jessica Forrester 2014-09-09 17:05:18 -04:00
parent becf6ca4e7
commit 0cac1c5f79
7 changed files with 85 additions and 92 deletions

View File

@ -24,6 +24,7 @@ import (
"net/http" "net/http"
"os" "os"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/apiserver" "github.com/GoogleCloudPlatform/kubernetes/pkg/apiserver"
@ -141,7 +142,7 @@ func main() {
if len(corsAllowedOriginList) > 0 { if len(corsAllowedOriginList) > 0 {
allowedOriginRegexps, err := util.CompileRegexps(corsAllowedOriginList) allowedOriginRegexps, err := util.CompileRegexps(corsAllowedOriginList)
if err != nil { if err != nil {
glog.Fatalf("Invalid CORS allowed origin: %v", err) glog.Fatalf("Invalid CORS allowed origin, --cors_allowed_origins flag was set to %v - %v", strings.Join(corsAllowedOriginList, ","), err)
} }
handler = apiserver.CORS(handler, allowedOriginRegexps, nil, nil, "true") handler = apiserver.CORS(handler, allowedOriginRegexps, nil, nil, "true")
} }

View File

@ -25,7 +25,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"regexp"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@ -35,6 +34,7 @@ import (
apierrs "github.com/GoogleCloudPlatform/kubernetes/pkg/api/errors" apierrs "github.com/GoogleCloudPlatform/kubernetes/pkg/api/errors"
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels" "github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/GoogleCloudPlatform/kubernetes/pkg/version" "github.com/GoogleCloudPlatform/kubernetes/pkg/version"
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch" "github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
) )
@ -733,68 +733,72 @@ func TestSyncCreateTimeout(t *testing.T) {
} }
} }
func TestCORSAllowedOrigin(t *testing.T) { func TestCORSAllowedOrigins(t *testing.T) {
handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []*regexp.Regexp{regexp.MustCompile("example.com")}, nil, nil, "true") table := []struct {
server := httptest.NewServer(handler) allowedOrigins util.StringList
client := http.Client{} origin string
allowed bool
request, err := http.NewRequest("GET", server.URL+"/version", nil) }{
if err != nil { {[]string{}, "example.com", false},
t.Errorf("unexpected error: %v", err) {[]string{"example.com"}, "example.com", true},
} {[]string{"example.com"}, "not-allowed.com", false},
request.Header.Set("Origin", "example.com") {[]string{"not-matching.com", "example.com"}, "example.com", true},
{[]string{".*"}, "example.com", true},
response, err := client.Do(request)
if err != nil {
t.Errorf("unexpected error: %v", err)
} }
if !reflect.DeepEqual(response.Header.Get("Access-Control-Allow-Origin"), "example.com") { for _, item := range table {
t.Errorf("Expected %#v, Got %#v", response.Header.Get("Access-Control-Allow-Origin"), "example.com") allowedOriginRegexps, err := util.CompileRegexps(item.allowedOrigins)
} if err != nil {
t.Errorf("unexpected error: %v", err)
}
if response.Header.Get("Access-Control-Allow-Credentials") == "" { handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), allowedOriginRegexps, nil, nil, "true")
t.Errorf("Expected Access-Control-Allow-Credentials header to be set") server := httptest.NewServer(handler)
} client := http.Client{}
if response.Header.Get("Access-Control-Allow-Headers") == "" { request, err := http.NewRequest("GET", server.URL+"/version", nil)
t.Errorf("Expected Access-Control-Allow-Headers header to be set") if err != nil {
} t.Errorf("unexpected error: %v", err)
}
request.Header.Set("Origin", item.origin)
if response.Header.Get("Access-Control-Allow-Methods") == "" { response, err := client.Do(request)
t.Errorf("Expected Access-Control-Allow-Methods header to be set") if err != nil {
} t.Errorf("unexpected error: %v", err)
} }
func TestCORSUnallowedOrigin(t *testing.T) { if item.allowed {
handler := CORS(Handle(map[string]RESTStorage{}, codec, "/prefix/version"), []*regexp.Regexp{regexp.MustCompile("example.com")}, nil, nil, "true") if !reflect.DeepEqual(item.origin, response.Header.Get("Access-Control-Allow-Origin")) {
server := httptest.NewServer(handler) t.Errorf("Expected %#v, Got %#v", item.origin, response.Header.Get("Access-Control-Allow-Origin"))
client := http.Client{} }
request, err := http.NewRequest("GET", server.URL+"/version", nil) if response.Header.Get("Access-Control-Allow-Credentials") == "" {
if err != nil { t.Errorf("Expected Access-Control-Allow-Credentials header to be set")
t.Errorf("unexpected error: %v", err) }
}
request.Header.Set("Origin", "not-allowed.com") if response.Header.Get("Access-Control-Allow-Headers") == "" {
t.Errorf("Expected Access-Control-Allow-Headers header to be set")
response, err := client.Do(request) }
if err != nil {
t.Errorf("unexpected error: %v", err) if response.Header.Get("Access-Control-Allow-Methods") == "" {
} t.Errorf("Expected Access-Control-Allow-Methods header to be set")
}
if response.Header.Get("Access-Control-Allow-Origin") != "" { } else {
t.Errorf("Expected Access-Control-Allow-Origin header to not be set") if response.Header.Get("Access-Control-Allow-Origin") != "" {
} t.Errorf("Expected Access-Control-Allow-Origin header to not be set")
}
if response.Header.Get("Access-Control-Allow-Credentials") != "" {
t.Errorf("Expected Access-Control-Allow-Credentials header to not be set") if response.Header.Get("Access-Control-Allow-Credentials") != "" {
} t.Errorf("Expected Access-Control-Allow-Credentials header to not be set")
}
if response.Header.Get("Access-Control-Allow-Headers") != "" {
t.Errorf("Expected Access-Control-Allow-Headers header to not be set") if response.Header.Get("Access-Control-Allow-Headers") != "" {
} t.Errorf("Expected Access-Control-Allow-Headers header to not be set")
}
if response.Header.Get("Access-Control-Allow-Methods") != "" {
t.Errorf("Expected Access-Control-Allow-Methods header to not be set") if response.Header.Get("Access-Control-Allow-Methods") != "" {
t.Errorf("Expected Access-Control-Allow-Methods header to not be set")
}
}
} }
} }

View File

@ -38,14 +38,14 @@ func (r *RedirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
id := parts[1] id := parts[1]
storage, ok := r.storage[resourceName] storage, ok := r.storage[resourceName]
if !ok { if !ok {
httplog.LogOf(w).Addf("'%v' has no storage object", resourceName) httplog.LogOf(req, w).Addf("'%v' has no storage object", resourceName)
notFound(w, req) notFound(w, req)
return return
} }
redirector, ok := storage.(Redirector) redirector, ok := storage.(Redirector)
if !ok { if !ok {
httplog.LogOf(w).Addf("'%v' is not a redirector", resourceName) httplog.LogOf(req, w).Addf("'%v' is not a redirector", resourceName)
notFound(w, req) notFound(w, req)
return return
} }

View File

@ -42,7 +42,7 @@ func (h *RESTHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
storage := h.storage[parts[0]] storage := h.storage[parts[0]]
if storage == nil { if storage == nil {
httplog.FindOrCreateLogOf(req, &w).Addf("'%v' has no storage object", parts[0]) httplog.LogOf(req, w).Addf("'%v' has no storage object", parts[0])
notFound(w, req) notFound(w, req)
return return
} }
@ -114,7 +114,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt
return return
} }
op := h.createOperation(out, sync, timeout) op := h.createOperation(out, sync, timeout)
h.finishReq(op, w) h.finishReq(op, req, w)
case "DELETE": case "DELETE":
if len(parts) != 2 { if len(parts) != 2 {
@ -127,7 +127,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt
return return
} }
op := h.createOperation(out, sync, timeout) op := h.createOperation(out, sync, timeout)
h.finishReq(op, w) h.finishReq(op, req, w)
case "PUT": case "PUT":
if len(parts) != 2 { if len(parts) != 2 {
@ -151,7 +151,7 @@ func (h *RESTHandler) handleRESTStorage(parts []string, req *http.Request, w htt
return return
} }
op := h.createOperation(out, sync, timeout) op := h.createOperation(out, sync, timeout)
h.finishReq(op, w) h.finishReq(op, req, w)
default: default:
notFound(w, req) notFound(w, req)
@ -171,7 +171,7 @@ func (h *RESTHandler) createOperation(out <-chan runtime.Object, sync bool, time
// finishReq finishes up a request, waiting until the operation finishes or, after a timeout, creating an // finishReq finishes up a request, waiting until the operation finishes or, after a timeout, creating an
// Operation to receive the result and returning its ID down the writer. // Operation to receive the result and returning its ID down the writer.
func (h *RESTHandler) finishReq(op *Operation, w http.ResponseWriter) { func (h *RESTHandler) finishReq(op *Operation, req *http.Request, w http.ResponseWriter) {
obj, complete := op.StatusOrResult() obj, complete := op.StatusOrResult()
if complete { if complete {
status := http.StatusOK status := http.StatusOK

View File

@ -127,7 +127,7 @@ func (w *WatchServer) HandleWS(ws *websocket.Conn) {
// ServeHTTP serves a series of JSON encoded events via straight HTTP with // ServeHTTP serves a series of JSON encoded events via straight HTTP with
// Transfer-Encoding: chunked. // Transfer-Encoding: chunked.
func (self *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (self *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
loggedW := httplog.FindOrCreateLogOf(req, &w) loggedW := httplog.LogOf(req, w)
w = httplog.Unlogged(w) w = httplog.Unlogged(w)
cn, ok := w.(http.CloseNotifier) cn, ok := w.(http.CloseNotifier)

View File

@ -86,22 +86,17 @@ func NewLogged(req *http.Request, w *http.ResponseWriter) *respLogger {
return rl return rl
} }
// LogOf returns the logger hiding in w. Panics if there isn't such a logger, // LogOf returns the logger hiding in w. If there is not an existing logger
// because NewLogged() must have been previously called for the log to work. // then one will be created because NewLogged() must have been previously
func LogOf(w http.ResponseWriter) *respLogger { // called for the log to work.
func LogOf(req *http.Request, w http.ResponseWriter) *respLogger {
if _, exists := w.(*respLogger); !exists {
NewLogged(req, &w)
}
if rl, ok := w.(*respLogger); ok { if rl, ok := w.(*respLogger); ok {
return rl return rl
} }
panic("Logger not installed yet!") panic("Unable to find or create the logger!")
}
// Returns the existing logger hiding in w. If there is not an existing logger
// then one will be created.
func FindOrCreateLogOf(req *http.Request, w *http.ResponseWriter) *respLogger {
if _, exists := (*w).(*respLogger); !exists {
NewLogged(req, w)
}
return LogOf(*w)
} }
// Unlogged returns the original ResponseWriter, or w if it is not our inserted logger. // Unlogged returns the original ResponseWriter, or w if it is not our inserted logger.

View File

@ -91,19 +91,12 @@ func TestLogOf(t *testing.T) {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
} }
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(w http.ResponseWriter, r *http.Request) {
var want *respLogger
if makeLogger { if makeLogger {
want = NewLogged(req, &w) NewLogged(req, &w)
} else {
defer func() {
if r := recover(); r == nil {
t.Errorf("Expected LogOf to panic")
}
}()
} }
got := LogOf(w) got := reflect.TypeOf(*LogOf(r, w)).String()
if want != got { if got != "httplog.respLogger" {
t.Errorf("Expected %v, got %v", want, got) t.Errorf("Expected %v, got %v", "httplog.respLogger", got)
} }
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()