Put user in context, map requests to context above resthandler layer

This commit is contained in:
Jordan Liggitt 2015-02-11 17:09:25 -05:00
parent ec66e5147e
commit 083ce268e0
14 changed files with 290 additions and 146 deletions

115
pkg/api/requestcontext.go Normal file
View File

@ -0,0 +1,115 @@
/*
Copyright 2014 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package api
import (
"errors"
"net/http"
"sync"
)
// RequestContextMapper keeps track of the context associated with a particular request
type RequestContextMapper interface {
// Get returns the context associated with the given request (if any), and true if the request has an associated context, and false if it does not.
Get(req *http.Request) (Context, bool)
// Update maps the request to the given context. If no context was previously associated with the request, an error is returned.
// Update should only be called with a descendant context of the previously associated context.
// Updating to an unrelated context may return an error in the future.
// The context associated with a request should only be updated by a limited set of callers.
// Valid examples include the authentication layer, or an audit/tracing layer.
Update(req *http.Request, context Context) error
}
type requestContextMap struct {
contexts map[*http.Request]Context
lock sync.Mutex
}
// NewRequestContextMapper returns a new RequestContextMapper.
// The returned mapper must be added as a request filter using NewRequestContextFilter.
func NewRequestContextMapper() RequestContextMapper {
return &requestContextMap{
contexts: make(map[*http.Request]Context),
}
}
// Get returns the context associated with the given request (if any), and true if the request has an associated context, and false if it does not.
// Get will only return a valid context when called from inside the filter chain set up by NewRequestContextFilter()
func (c *requestContextMap) Get(req *http.Request) (Context, bool) {
c.lock.Lock()
defer c.lock.Unlock()
context, ok := c.contexts[req]
return context, ok
}
// Update maps the request to the given context.
// If no context was previously associated with the request, an error is returned and the context is ignored.
func (c *requestContextMap) Update(req *http.Request, context Context) error {
c.lock.Lock()
defer c.lock.Unlock()
if _, ok := c.contexts[req]; !ok {
return errors.New("No context associated")
}
// TODO: ensure the new context is a descendant of the existing one
c.contexts[req] = context
return nil
}
// init maps the request to the given context and returns true if there was no context associated with the request already.
// if a context was already associated with the request, it ignores the given context and returns false.
// init is intentionally unexported to ensure that all init calls are paired with a remove after a request is handled
func (c *requestContextMap) init(req *http.Request, context Context) bool {
c.lock.Lock()
defer c.lock.Unlock()
if _, exists := c.contexts[req]; exists {
return false
}
c.contexts[req] = context
return true
}
// remove is intentionally unexported to ensure that the context is not removed until a request is handled
func (c *requestContextMap) remove(req *http.Request) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.contexts, req)
}
// NewRequestContextFilter ensures there is a Context object associated with the request before calling the passed handler.
// After the passed handler runs, the context is cleaned up.
func NewRequestContextFilter(mapper RequestContextMapper, handler http.Handler) (http.Handler, error) {
if mapper, ok := mapper.(*requestContextMap); ok {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if mapper.init(req, NewContext()) {
// If we were the ones to successfully initialize, pair with a remove
defer mapper.remove(req)
}
handler.ServeHTTP(w, req)
}), nil
} else {
return handler, errors.New("Unknown RequestContextMapper implementation.")
}
}
// IsEmpty returns true if there are no contexts registered, or an error if it could not be determined. Intended for use by tests.
func IsEmpty(requestsToContexts RequestContextMapper) (bool, error) {
if requestsToContexts, ok := requestsToContexts.(*requestContextMap); ok {
return len(requestsToContexts.contexts) == 0, nil
}
return true, errors.New("Unknown RequestContextMapper implementation")
}

View File

@ -62,8 +62,8 @@ func (a *APIInstaller) Install() (ws *restful.WebService, errors []error) {
linker: a.group.linker, linker: a.group.linker,
info: a.group.info, info: a.group.info,
}) })
redirectHandler := (&RedirectHandler{a.group.storage, a.group.codec, a.group.info}) redirectHandler := (&RedirectHandler{a.group.storage, a.group.codec, a.group.context, a.group.info})
proxyHandler := (&ProxyHandler{a.prefix + "/proxy/", a.group.storage, a.group.codec, a.group.info}) proxyHandler := (&ProxyHandler{a.prefix + "/proxy/", a.group.storage, a.group.codec, a.group.context, a.group.info})
for path, storage := range a.group.storage { for path, storage := range a.group.storage {
if err := a.registerResourceHandlers(path, storage, ws, watchHandler, redirectHandler, proxyHandler); err != nil { if err := a.registerResourceHandlers(path, storage, ws, watchHandler, redirectHandler, proxyHandler); err != nil {
@ -88,6 +88,7 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage RESTStorage
codec := a.group.codec codec := a.group.codec
admit := a.group.admit admit := a.group.admit
linker := a.group.linker linker := a.group.linker
context := a.group.context
resource := path resource := path
object := storage.New() object := storage.New()
@ -147,6 +148,7 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage RESTStorage
storageVerbs["Redirector"] = true storageVerbs["Redirector"] = true
} }
var ctxFn ContextFunc
var namespaceFn ResourceNamespaceFunc var namespaceFn ResourceNamespaceFunc
var nameFn ResourceNameFunc var nameFn ResourceNameFunc
var generateLinkFn linkFunc var generateLinkFn linkFunc
@ -154,6 +156,12 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage RESTStorage
linkFn := func(req *restful.Request, obj runtime.Object) error { linkFn := func(req *restful.Request, obj runtime.Object) error {
return setSelfLink(obj, req.Request, a.group.linker, generateLinkFn) return setSelfLink(obj, req.Request, a.group.linker, generateLinkFn)
} }
ctxFn = func(req *restful.Request) api.Context {
if ctx, ok := context.Get(req.Request); ok {
return ctx
}
return api.NewContext()
}
allowWatchList := storageVerbs["ResourceWatcher"] && storageVerbs["RESTLister"] // watching on lists is allowed only for kinds that support both watch and list. allowWatchList := storageVerbs["ResourceWatcher"] && storageVerbs["RESTLister"] // watching on lists is allowed only for kinds that support both watch and list.
scope := mapping.Scope scope := mapping.Scope
@ -324,7 +332,7 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage RESTStorage
m := monitorFilter(action.Verb, resource) m := monitorFilter(action.Verb, resource)
switch action.Verb { switch action.Verb {
case "GET": // Get a resource. case "GET": // Get a resource.
route := ws.GET(action.Path).To(GetResource(getter, nameFn, linkFn, codec)). route := ws.GET(action.Path).To(GetResource(getter, ctxFn, nameFn, linkFn, codec)).
Filter(m). Filter(m).
Doc("read the specified " + kind). Doc("read the specified " + kind).
Operation("read" + kind). Operation("read" + kind).
@ -332,7 +340,7 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage RESTStorage
addParams(route, action.Params) addParams(route, action.Params)
ws.Route(route) ws.Route(route)
case "LIST": // List all resources of a kind. case "LIST": // List all resources of a kind.
route := ws.GET(action.Path).To(ListResource(lister, namespaceFn, linkFn, codec)). route := ws.GET(action.Path).To(ListResource(lister, ctxFn, namespaceFn, linkFn, codec)).
Filter(m). Filter(m).
Doc("list objects of kind " + kind). Doc("list objects of kind " + kind).
Operation("list" + kind). Operation("list" + kind).
@ -340,7 +348,7 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage RESTStorage
addParams(route, action.Params) addParams(route, action.Params)
ws.Route(route) ws.Route(route)
case "PUT": // Update a resource. case "PUT": // Update a resource.
route := ws.PUT(action.Path).To(UpdateResource(updater, nameFn, objNameFn, linkFn, codec, resource, admit)). route := ws.PUT(action.Path).To(UpdateResource(updater, ctxFn, nameFn, objNameFn, linkFn, codec, resource, admit)).
Filter(m). Filter(m).
Doc("update the specified " + kind). Doc("update the specified " + kind).
Operation("update" + kind). Operation("update" + kind).
@ -348,7 +356,7 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage RESTStorage
addParams(route, action.Params) addParams(route, action.Params)
ws.Route(route) ws.Route(route)
case "POST": // Create a resource. case "POST": // Create a resource.
route := ws.POST(action.Path).To(CreateResource(creater, namespaceFn, linkFn, codec, resource, admit)). route := ws.POST(action.Path).To(CreateResource(creater, ctxFn, namespaceFn, linkFn, codec, resource, admit)).
Filter(m). Filter(m).
Doc("create a " + kind). Doc("create a " + kind).
Operation("create" + kind). Operation("create" + kind).
@ -356,7 +364,7 @@ func (a *APIInstaller) registerResourceHandlers(path string, storage RESTStorage
addParams(route, action.Params) addParams(route, action.Params)
ws.Route(route) ws.Route(route)
case "DELETE": // Delete a resource. case "DELETE": // Delete a resource.
route := ws.DELETE(action.Path).To(DeleteResource(deleter, nameFn, linkFn, codec, resource, kind, admit)). route := ws.DELETE(action.Path).To(DeleteResource(deleter, ctxFn, nameFn, linkFn, codec, resource, kind, admit)).
Filter(m). Filter(m).
Doc("delete a " + kind). Doc("delete a " + kind).
Operation("delete" + kind) Operation("delete" + kind)

View File

@ -98,9 +98,9 @@ type defaultAPIServer struct {
// as RESTful resources at prefix, serialized by codec, and also includes the support // as RESTful resources at prefix, serialized by codec, and also includes the support
// http resources. // http resources.
// Note: This method is used only in tests. // Note: This method is used only in tests.
func Handle(storage map[string]RESTStorage, codec runtime.Codec, root string, version string, linker runtime.SelfLinker, admissionControl admission.Interface, mapper meta.RESTMapper) http.Handler { func Handle(storage map[string]RESTStorage, codec runtime.Codec, root string, version string, linker runtime.SelfLinker, admissionControl admission.Interface, contextMapper api.RequestContextMapper, mapper meta.RESTMapper) http.Handler {
prefix := path.Join(root, version) prefix := path.Join(root, version)
group := NewAPIGroupVersion(storage, codec, root, prefix, linker, admissionControl, mapper) group := NewAPIGroupVersion(storage, codec, root, prefix, linker, admissionControl, contextMapper, mapper)
container := restful.NewContainer() container := restful.NewContainer()
container.Router(restful.CurlyRouter{}) container.Router(restful.CurlyRouter{})
mux := container.ServeMux mux := container.ServeMux
@ -121,6 +121,7 @@ type APIGroupVersion struct {
prefix string prefix string
linker runtime.SelfLinker linker runtime.SelfLinker
admit admission.Interface admit admission.Interface
context api.RequestContextMapper
mapper meta.RESTMapper mapper meta.RESTMapper
// TODO: put me into a cleaner interface // TODO: put me into a cleaner interface
info *APIRequestInfoResolver info *APIRequestInfoResolver
@ -131,13 +132,14 @@ type APIGroupVersion struct {
// This is a helper method for registering multiple sets of REST handlers under different // This is a helper method for registering multiple sets of REST handlers under different
// prefixes onto a server. // prefixes onto a server.
// TODO: add multitype codec serialization // TODO: add multitype codec serialization
func NewAPIGroupVersion(storage map[string]RESTStorage, codec runtime.Codec, root, prefix string, linker runtime.SelfLinker, admissionControl admission.Interface, mapper meta.RESTMapper) *APIGroupVersion { func NewAPIGroupVersion(storage map[string]RESTStorage, codec runtime.Codec, root, prefix string, linker runtime.SelfLinker, admissionControl admission.Interface, contextMapper api.RequestContextMapper, mapper meta.RESTMapper) *APIGroupVersion {
return &APIGroupVersion{ return &APIGroupVersion{
storage: storage, storage: storage,
codec: codec, codec: codec,
prefix: prefix, prefix: prefix,
linker: linker, linker: linker,
admit: admissionControl, admit: admissionControl,
context: contextMapper,
mapper: mapper, mapper: mapper,
info: &APIRequestInfoResolver{util.NewStringSet(root), latest.RESTMapper}, info: &APIRequestInfoResolver{util.NewStringSet(root), latest.RESTMapper},
} }

View File

@ -57,6 +57,7 @@ var versioner runtime.ResourceVersioner = accessor
var selfLinker runtime.SelfLinker = accessor var selfLinker runtime.SelfLinker = accessor
var mapper, namespaceMapper, legacyNamespaceMapper meta.RESTMapper // The mappers with namespace and with legacy namespace scopes. var mapper, namespaceMapper, legacyNamespaceMapper meta.RESTMapper // The mappers with namespace and with legacy namespace scopes.
var admissionControl admission.Interface var admissionControl admission.Interface
var requestContextMapper api.RequestContextMapper
func interfacesFor(version string) (*meta.VersionInterfaces, error) { func interfacesFor(version string) (*meta.VersionInterfaces, error) {
switch version { switch version {
@ -111,6 +112,7 @@ func init() {
legacyNamespaceMapper = legacyNsMapper legacyNamespaceMapper = legacyNsMapper
namespaceMapper = nsMapper namespaceMapper = nsMapper
admissionControl = admit.NewAlwaysAdmit() admissionControl = admit.NewAlwaysAdmit()
requestContextMapper = api.NewRequestContextMapper()
} }
type Simple struct { type Simple struct {
@ -283,7 +285,7 @@ func TestNotFound(t *testing.T) {
} }
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": &SimpleRESTStorage{}, "foo": &SimpleRESTStorage{},
}, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) }, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
client := http.Client{} client := http.Client{}
@ -335,7 +337,7 @@ func TestUnimplementedRESTStorage(t *testing.T) {
} }
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": UnimplementedRESTStorage{}, "foo": UnimplementedRESTStorage{},
}, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) }, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
client := http.Client{} client := http.Client{}
@ -360,7 +362,7 @@ func TestUnimplementedRESTStorage(t *testing.T) {
} }
func TestVersion(t *testing.T) { func TestVersion(t *testing.T) {
handler := Handle(map[string]RESTStorage{}, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(map[string]RESTStorage{}, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
client := http.Client{} client := http.Client{}
@ -395,7 +397,7 @@ func TestSimpleList(t *testing.T) {
namespace: "other", namespace: "other",
expectedSet: "/prefix/version/simple?namespace=other", expectedSet: "/prefix/version/simple?namespace=other",
} }
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -418,7 +420,7 @@ func TestErrorList(t *testing.T) {
errors: map[string]error{"list": fmt.Errorf("test Error")}, errors: map[string]error{"list": fmt.Errorf("test Error")},
} }
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -444,7 +446,7 @@ func TestNonEmptyList(t *testing.T) {
}, },
} }
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -492,7 +494,7 @@ func TestGet(t *testing.T) {
namespace: "default", namespace: "default",
} }
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -531,7 +533,7 @@ func TestGetAlternateSelfLink(t *testing.T) {
namespace: "test", namespace: "test",
} }
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, legacyNamespaceMapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, legacyNamespaceMapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -569,7 +571,7 @@ func TestGetNamespaceSelfLink(t *testing.T) {
namespace: "foo", namespace: "foo",
} }
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, namespaceMapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, namespaceMapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -598,7 +600,7 @@ func TestGetMissing(t *testing.T) {
errors: map[string]error{"get": apierrs.NewNotFound("simple", "id")}, errors: map[string]error{"get": apierrs.NewNotFound("simple", "id")},
} }
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -617,7 +619,7 @@ func TestDelete(t *testing.T) {
simpleStorage := SimpleRESTStorage{} simpleStorage := SimpleRESTStorage{}
ID := "id" ID := "id"
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -640,7 +642,7 @@ func TestDeleteInvokesAdmissionControl(t *testing.T) {
simpleStorage := SimpleRESTStorage{} simpleStorage := SimpleRESTStorage{}
ID := "id" ID := "id"
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, deny.NewAlwaysDeny(), mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, deny.NewAlwaysDeny(), requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -662,7 +664,7 @@ func TestDeleteMissing(t *testing.T) {
errors: map[string]error{"delete": apierrs.NewNotFound("simple", ID)}, errors: map[string]error{"delete": apierrs.NewNotFound("simple", ID)},
} }
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -689,7 +691,7 @@ func TestUpdate(t *testing.T) {
name: ID, name: ID,
namespace: api.NamespaceDefault, namespace: api.NamespaceDefault,
} }
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -726,7 +728,7 @@ func TestUpdateInvokesAdmissionControl(t *testing.T) {
simpleStorage := SimpleRESTStorage{} simpleStorage := SimpleRESTStorage{}
ID := "id" ID := "id"
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, deny.NewAlwaysDeny(), mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, deny.NewAlwaysDeny(), requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -759,7 +761,7 @@ func TestUpdateRequiresMatchingName(t *testing.T) {
simpleStorage := SimpleRESTStorage{} simpleStorage := SimpleRESTStorage{}
ID := "id" ID := "id"
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, deny.NewAlwaysDeny(), mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, deny.NewAlwaysDeny(), requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -788,7 +790,7 @@ func TestUpdateAllowsMissingNamespace(t *testing.T) {
simpleStorage := SimpleRESTStorage{} simpleStorage := SimpleRESTStorage{}
ID := "id" ID := "id"
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -820,7 +822,7 @@ func TestUpdatePreventsMismatchedNamespace(t *testing.T) {
simpleStorage := SimpleRESTStorage{} simpleStorage := SimpleRESTStorage{}
ID := "id" ID := "id"
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -855,7 +857,7 @@ func TestUpdateMissing(t *testing.T) {
errors: map[string]error{"update": apierrs.NewNotFound("simple", ID)}, errors: map[string]error{"update": apierrs.NewNotFound("simple", ID)},
} }
storage["simple"] = &simpleStorage storage["simple"] = &simpleStorage
handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(storage, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -889,7 +891,7 @@ func TestCreateNotFound(t *testing.T) {
// See https://github.com/GoogleCloudPlatform/kubernetes/pull/486#discussion_r15037092. // See https://github.com/GoogleCloudPlatform/kubernetes/pull/486#discussion_r15037092.
errors: map[string]error{"create": apierrs.NewNotFound("simple", "id")}, errors: map[string]error{"create": apierrs.NewNotFound("simple", "id")},
}, },
}, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) }, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
client := http.Client{} client := http.Client{}
@ -957,7 +959,7 @@ func TestCreate(t *testing.T) {
} }
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": &storage, "foo": &storage,
}, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) }, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
client := http.Client{} client := http.Client{}
@ -1015,7 +1017,7 @@ func TestCreateInvokesAdmissionControl(t *testing.T) {
} }
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": &storage, "foo": &storage,
}, codec, "/prefix", testVersion, selfLinker, deny.NewAlwaysDeny(), mapper) }, codec, "/prefix", testVersion, selfLinker, deny.NewAlwaysDeny(), requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
client := http.Client{} client := http.Client{}
@ -1075,7 +1077,7 @@ func TestDelayReturnsError(t *testing.T) {
return nil, apierrs.NewAlreadyExists("foo", "bar") return nil, apierrs.NewAlreadyExists("foo", "bar")
}, },
} }
handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) handler := Handle(map[string]RESTStorage{"foo": &storage}, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -1141,7 +1143,7 @@ func TestCreateTimeout(t *testing.T) {
} }
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": &storage, "foo": &storage,
}, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper) }, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -1173,7 +1175,7 @@ func TestCORSAllowedOrigins(t *testing.T) {
} }
handler := CORS( handler := CORS(
Handle(map[string]RESTStorage{}, codec, "/prefix", testVersion, selfLinker, admissionControl, mapper), Handle(map[string]RESTStorage{}, codec, "/prefix", testVersion, selfLinker, admissionControl, requestContextMapper, mapper),
allowedOriginRegexps, nil, nil, "true", allowedOriginRegexps, nil, nil, "true",
) )
server := httptest.NewServer(handler) server := httptest.NewServer(handler)

View File

@ -27,7 +27,6 @@ import (
"github.com/GoogleCloudPlatform/kubernetes/pkg/api/errors" "github.com/GoogleCloudPlatform/kubernetes/pkg/api/errors"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api/meta" "github.com/GoogleCloudPlatform/kubernetes/pkg/api/meta"
"github.com/GoogleCloudPlatform/kubernetes/pkg/auth/authorizer" "github.com/GoogleCloudPlatform/kubernetes/pkg/auth/authorizer"
authhandlers "github.com/GoogleCloudPlatform/kubernetes/pkg/auth/handlers"
"github.com/GoogleCloudPlatform/kubernetes/pkg/httplog" "github.com/GoogleCloudPlatform/kubernetes/pkg/httplog"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util" "github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/golang/glog" "github.com/golang/glog"
@ -154,21 +153,24 @@ type RequestAttributeGetter interface {
} }
type requestAttributeGetter struct { type requestAttributeGetter struct {
userContexts authhandlers.RequestContext requestContextMapper api.RequestContextMapper
apiRequestInfoResolver *APIRequestInfoResolver apiRequestInfoResolver *APIRequestInfoResolver
} }
// NewAttributeGetter returns an object which implements the RequestAttributeGetter interface. // NewAttributeGetter returns an object which implements the RequestAttributeGetter interface.
func NewRequestAttributeGetter(userContexts authhandlers.RequestContext, restMapper meta.RESTMapper, apiRoots ...string) RequestAttributeGetter { func NewRequestAttributeGetter(requestContextMapper api.RequestContextMapper, restMapper meta.RESTMapper, apiRoots ...string) RequestAttributeGetter {
return &requestAttributeGetter{userContexts, &APIRequestInfoResolver{util.NewStringSet(apiRoots...), restMapper}} return &requestAttributeGetter{requestContextMapper, &APIRequestInfoResolver{util.NewStringSet(apiRoots...), restMapper}}
} }
func (r *requestAttributeGetter) GetAttribs(req *http.Request) authorizer.Attributes { func (r *requestAttributeGetter) GetAttribs(req *http.Request) authorizer.Attributes {
attribs := authorizer.AttributesRecord{} attribs := authorizer.AttributesRecord{}
user, ok := r.userContexts.Get(req) ctx, ok := r.requestContextMapper.Get(req)
if ok { if ok {
attribs.User = user user, ok := api.UserFrom(ctx)
if ok {
attribs.User = user
}
} }
attribs.ReadOnly = IsReadOnlyReq(*req) attribs.ReadOnly = IsReadOnlyReq(*req)

View File

@ -78,6 +78,7 @@ type ProxyHandler struct {
prefix string prefix string
storage map[string]RESTStorage storage map[string]RESTStorage
codec runtime.Codec codec runtime.Codec
context api.RequestContextMapper
apiRequestInfoResolver *APIRequestInfoResolver apiRequestInfoResolver *APIRequestInfoResolver
} }
@ -97,7 +98,11 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
verb = requestInfo.Verb verb = requestInfo.Verb
namespace, resource, parts := requestInfo.Namespace, requestInfo.Resource, requestInfo.Parts namespace, resource, parts := requestInfo.Namespace, requestInfo.Resource, requestInfo.Parts
ctx := api.WithNamespace(api.NewContext(), namespace) ctx, ok := r.context.Get(req)
if !ok {
ctx = api.NewContext()
}
ctx = api.WithNamespace(ctx, namespace)
if len(parts) < 2 { if len(parts) < 2 {
notFound(w, req) notFound(w, req)
httpCode = http.StatusNotFound httpCode = http.StatusNotFound

View File

@ -281,12 +281,12 @@ func TestProxy(t *testing.T) {
namespaceHandler := Handle(map[string]RESTStorage{ namespaceHandler := Handle(map[string]RESTStorage{
"foo": simpleStorage, "foo": simpleStorage,
}, codec, "/prefix", "version", selfLinker, admissionControl, namespaceMapper) }, codec, "/prefix", "version", selfLinker, admissionControl, requestContextMapper, namespaceMapper)
namespaceServer := httptest.NewServer(namespaceHandler) namespaceServer := httptest.NewServer(namespaceHandler)
defer namespaceServer.Close() defer namespaceServer.Close()
legacyNamespaceHandler := Handle(map[string]RESTStorage{ legacyNamespaceHandler := Handle(map[string]RESTStorage{
"foo": simpleStorage, "foo": simpleStorage,
}, codec, "/prefix", "version", selfLinker, admissionControl, legacyNamespaceMapper) }, codec, "/prefix", "version", selfLinker, admissionControl, requestContextMapper, legacyNamespaceMapper)
legacyNamespaceServer := httptest.NewServer(legacyNamespaceHandler) legacyNamespaceServer := httptest.NewServer(legacyNamespaceHandler)
defer legacyNamespaceServer.Close() defer legacyNamespaceServer.Close()

View File

@ -29,6 +29,7 @@ import (
type RedirectHandler struct { type RedirectHandler struct {
storage map[string]RESTStorage storage map[string]RESTStorage
codec runtime.Codec codec runtime.Codec
context api.RequestContextMapper
apiRequestInfoResolver *APIRequestInfoResolver apiRequestInfoResolver *APIRequestInfoResolver
} }
@ -47,7 +48,11 @@ func (r *RedirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
verb = requestInfo.Verb verb = requestInfo.Verb
resource, parts := requestInfo.Resource, requestInfo.Parts resource, parts := requestInfo.Resource, requestInfo.Parts
ctx := api.WithNamespace(api.NewContext(), requestInfo.Namespace) ctx, ok := r.context.Get(req)
if !ok {
ctx = api.NewContext()
}
ctx = api.WithNamespace(ctx, requestInfo.Namespace)
// redirection requires /resource/resourceName path parts // redirection requires /resource/resourceName path parts
if len(parts) != 2 || req.Method != "GET" { if len(parts) != 2 || req.Method != "GET" {

View File

@ -31,7 +31,7 @@ func TestRedirect(t *testing.T) {
} }
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": simpleStorage, "foo": simpleStorage,
}, codec, "/prefix", "version", selfLinker, admissionControl, mapper) }, codec, "/prefix", "version", selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -84,7 +84,7 @@ func TestRedirectWithNamespaces(t *testing.T) {
} }
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": simpleStorage, "foo": simpleStorage,
}, codec, "/prefix", "version", selfLinker, admissionControl, namespaceMapper) }, codec, "/prefix", "version", selfLinker, admissionControl, requestContextMapper, namespaceMapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()

View File

@ -29,6 +29,9 @@ import (
"github.com/emicklei/go-restful" "github.com/emicklei/go-restful"
) )
// ContextFunc returns a Context given a request - a context must be returned
type ContextFunc func(req *restful.Request) api.Context
// ResourceNameFunc returns a name (and optional namespace) given a request - if no name is present // ResourceNameFunc returns a name (and optional namespace) given a request - if no name is present
// an error must be returned. // an error must be returned.
type ResourceNameFunc func(req *restful.Request) (namespace, name string, err error) type ResourceNameFunc func(req *restful.Request) (namespace, name string, err error)
@ -45,7 +48,7 @@ type ResourceNamespaceFunc func(req *restful.Request) (namespace string, err err
type LinkResourceFunc func(req *restful.Request, obj runtime.Object) error type LinkResourceFunc func(req *restful.Request, obj runtime.Object) error
// GetResource returns a function that handles retrieving a single resource from a RESTStorage object. // GetResource returns a function that handles retrieving a single resource from a RESTStorage object.
func GetResource(r RESTGetter, nameFn ResourceNameFunc, linkFn LinkResourceFunc, codec runtime.Codec) restful.RouteFunction { func GetResource(r RESTGetter, ctxFn ContextFunc, nameFn ResourceNameFunc, linkFn LinkResourceFunc, codec runtime.Codec) restful.RouteFunction {
return func(req *restful.Request, res *restful.Response) { return func(req *restful.Request, res *restful.Response) {
w := res.ResponseWriter w := res.ResponseWriter
namespace, name, err := nameFn(req) namespace, name, err := nameFn(req)
@ -53,7 +56,7 @@ func GetResource(r RESTGetter, nameFn ResourceNameFunc, linkFn LinkResourceFunc,
notFound(w, req.Request) notFound(w, req.Request)
return return
} }
ctx := api.NewContext() ctx := ctxFn(req)
if len(namespace) > 0 { if len(namespace) > 0 {
ctx = api.WithNamespace(ctx, namespace) ctx = api.WithNamespace(ctx, namespace)
} }
@ -71,7 +74,7 @@ func GetResource(r RESTGetter, nameFn ResourceNameFunc, linkFn LinkResourceFunc,
} }
// ListResource returns a function that handles retrieving a list of resources from a RESTStorage object. // ListResource returns a function that handles retrieving a list of resources from a RESTStorage object.
func ListResource(r RESTLister, namespaceFn ResourceNamespaceFunc, linkFn LinkResourceFunc, codec runtime.Codec) restful.RouteFunction { func ListResource(r RESTLister, ctxFn ContextFunc, namespaceFn ResourceNamespaceFunc, linkFn LinkResourceFunc, codec runtime.Codec) restful.RouteFunction {
return func(req *restful.Request, res *restful.Response) { return func(req *restful.Request, res *restful.Response) {
w := res.ResponseWriter w := res.ResponseWriter
@ -80,7 +83,7 @@ func ListResource(r RESTLister, namespaceFn ResourceNamespaceFunc, linkFn LinkRe
notFound(w, req.Request) notFound(w, req.Request)
return return
} }
ctx := api.NewContext() ctx := ctxFn(req)
if len(namespace) > 0 { if len(namespace) > 0 {
ctx = api.WithNamespace(ctx, namespace) ctx = api.WithNamespace(ctx, namespace)
} }
@ -109,7 +112,7 @@ func ListResource(r RESTLister, namespaceFn ResourceNamespaceFunc, linkFn LinkRe
} }
// CreateResource returns a function that will handle a resource creation. // CreateResource returns a function that will handle a resource creation.
func CreateResource(r RESTCreater, namespaceFn ResourceNamespaceFunc, linkFn LinkResourceFunc, codec runtime.Codec, resource string, admit admission.Interface) restful.RouteFunction { func CreateResource(r RESTCreater, ctxFn ContextFunc, namespaceFn ResourceNamespaceFunc, linkFn LinkResourceFunc, codec runtime.Codec, resource string, admit admission.Interface) restful.RouteFunction {
return func(req *restful.Request, res *restful.Response) { return func(req *restful.Request, res *restful.Response) {
w := res.ResponseWriter w := res.ResponseWriter
@ -121,7 +124,7 @@ func CreateResource(r RESTCreater, namespaceFn ResourceNamespaceFunc, linkFn Lin
notFound(w, req.Request) notFound(w, req.Request)
return return
} }
ctx := api.NewContext() ctx := ctxFn(req)
if len(namespace) > 0 { if len(namespace) > 0 {
ctx = api.WithNamespace(ctx, namespace) ctx = api.WithNamespace(ctx, namespace)
} }
@ -162,7 +165,7 @@ func CreateResource(r RESTCreater, namespaceFn ResourceNamespaceFunc, linkFn Lin
} }
// UpdateResource returns a function that will handle a resource update // UpdateResource returns a function that will handle a resource update
func UpdateResource(r RESTUpdater, nameFn ResourceNameFunc, objNameFunc ObjectNameFunc, linkFn LinkResourceFunc, codec runtime.Codec, resource string, admit admission.Interface) restful.RouteFunction { func UpdateResource(r RESTUpdater, ctxFn ContextFunc, nameFn ResourceNameFunc, objNameFunc ObjectNameFunc, linkFn LinkResourceFunc, codec runtime.Codec, resource string, admit admission.Interface) restful.RouteFunction {
return func(req *restful.Request, res *restful.Response) { return func(req *restful.Request, res *restful.Response) {
w := res.ResponseWriter w := res.ResponseWriter
@ -174,7 +177,7 @@ func UpdateResource(r RESTUpdater, nameFn ResourceNameFunc, objNameFunc ObjectNa
notFound(w, req.Request) notFound(w, req.Request)
return return
} }
ctx := api.NewContext() ctx := ctxFn(req)
if len(namespace) > 0 { if len(namespace) > 0 {
ctx = api.WithNamespace(ctx, namespace) ctx = api.WithNamespace(ctx, namespace)
} }
@ -238,7 +241,7 @@ func UpdateResource(r RESTUpdater, nameFn ResourceNameFunc, objNameFunc ObjectNa
} }
// DeleteResource returns a function that will handle a resource deletion // DeleteResource returns a function that will handle a resource deletion
func DeleteResource(r RESTDeleter, nameFn ResourceNameFunc, linkFn LinkResourceFunc, codec runtime.Codec, resource, kind string, admit admission.Interface) restful.RouteFunction { func DeleteResource(r RESTDeleter, ctxFn ContextFunc, nameFn ResourceNameFunc, linkFn LinkResourceFunc, codec runtime.Codec, resource, kind string, admit admission.Interface) restful.RouteFunction {
return func(req *restful.Request, res *restful.Response) { return func(req *restful.Request, res *restful.Response) {
w := res.ResponseWriter w := res.ResponseWriter
@ -250,7 +253,7 @@ func DeleteResource(r RESTDeleter, nameFn ResourceNameFunc, linkFn LinkResourceF
notFound(w, req.Request) notFound(w, req.Request)
return return
} }
ctx := api.NewContext() ctx := ctxFn(req)
if len(namespace) > 0 { if len(namespace) > 0 {
ctx = api.WithNamespace(ctx, namespace) ctx = api.WithNamespace(ctx, namespace)
} }

View File

@ -50,7 +50,7 @@ func TestWatchWebsocket(t *testing.T) {
_ = ResourceWatcher(simpleStorage) // Give compile error if this doesn't work. _ = ResourceWatcher(simpleStorage) // Give compile error if this doesn't work.
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": simpleStorage, "foo": simpleStorage,
}, codec, "/api", "version", selfLinker, admissionControl, mapper) }, codec, "/api", "version", selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -104,7 +104,7 @@ func TestWatchHTTP(t *testing.T) {
simpleStorage := &SimpleRESTStorage{} simpleStorage := &SimpleRESTStorage{}
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": simpleStorage, "foo": simpleStorage,
}, codec, "/api", "version", selfLinker, admissionControl, mapper) }, codec, "/api", "version", selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
client := http.Client{} client := http.Client{}
@ -167,7 +167,7 @@ func TestWatchParamParsing(t *testing.T) {
simpleStorage := &SimpleRESTStorage{} simpleStorage := &SimpleRESTStorage{}
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": simpleStorage, "foo": simpleStorage,
}, codec, "/api", "version", selfLinker, admissionControl, mapper) }, codec, "/api", "version", selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
@ -239,7 +239,7 @@ func TestWatchProtocolSelection(t *testing.T) {
simpleStorage := &SimpleRESTStorage{} simpleStorage := &SimpleRESTStorage{}
handler := Handle(map[string]RESTStorage{ handler := Handle(map[string]RESTStorage{
"foo": simpleStorage, "foo": simpleStorage,
}, codec, "/api", "version", selfLinker, admissionControl, mapper) }, codec, "/api", "version", selfLinker, admissionControl, requestContextMapper, mapper)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()
defer server.CloseClientConnections() defer server.CloseClientConnections()

View File

@ -18,39 +18,35 @@ package handlers
import ( import (
"net/http" "net/http"
"sync"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/auth/authenticator" "github.com/GoogleCloudPlatform/kubernetes/pkg/auth/authenticator"
"github.com/GoogleCloudPlatform/kubernetes/pkg/auth/user"
"github.com/golang/glog" "github.com/golang/glog"
) )
// RequestContext is the interface used to associate a user with an http Request.
type RequestContext interface {
Set(*http.Request, user.Info)
Get(req *http.Request) (user.Info, bool)
Remove(*http.Request)
}
// NewRequestAuthenticator creates an http handler that tries to authenticate the given request as a user, and then // NewRequestAuthenticator creates an http handler that tries to authenticate the given request as a user, and then
// stores any such user found onto the provided context for the request. If authentication fails or returns an error // stores any such user found onto the provided context for the request. If authentication fails or returns an error
// the failed handler is used. On success, handler is invoked to serve the request. // the failed handler is used. On success, handler is invoked to serve the request.
func NewRequestAuthenticator(context RequestContext, auth authenticator.Request, failed http.Handler, handler http.Handler) http.Handler { func NewRequestAuthenticator(mapper api.RequestContextMapper, auth authenticator.Request, failed http.Handler, handler http.Handler) (http.Handler, error) {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return api.NewRequestContextFilter(
user, ok, err := auth.AuthenticateRequest(req) mapper,
if err != nil || !ok { http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if err != nil { user, ok, err := auth.AuthenticateRequest(req)
glog.Errorf("Unable to authenticate the request due to an error: %v", err) if err != nil || !ok {
if err != nil {
glog.Errorf("Unable to authenticate the request due to an error: %v", err)
}
failed.ServeHTTP(w, req)
return
} }
failed.ServeHTTP(w, req)
return
}
context.Set(req, user) if ctx, ok := mapper.Get(req); ok {
defer context.Remove(req) mapper.Update(req, api.WithUser(ctx, user))
}
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
}) }),
)
} }
var Unauthorized http.HandlerFunc = unauthorized var Unauthorized http.HandlerFunc = unauthorized
@ -59,38 +55,3 @@ var Unauthorized http.HandlerFunc = unauthorized
func unauthorized(w http.ResponseWriter, req *http.Request) { func unauthorized(w http.ResponseWriter, req *http.Request) {
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
} }
// UserRequestContext allows different levels of a call stack to store/retrieve info about the
// current user associated with an http.Request.
type UserRequestContext struct {
requests map[*http.Request]user.Info
lock sync.Mutex
}
// NewUserRequestContext provides a map for storing and retrieving users associated with requests.
// Be sure to pair each `context.Set(req, user)` call with a `defer context.Remove(req)` call or
// you will leak requests. It implements the RequestContext interface.
func NewUserRequestContext() *UserRequestContext {
return &UserRequestContext{
requests: make(map[*http.Request]user.Info),
}
}
func (c *UserRequestContext) Get(req *http.Request) (user.Info, bool) {
c.lock.Lock()
defer c.lock.Unlock()
user, ok := c.requests[req]
return user, ok
}
func (c *UserRequestContext) Set(req *http.Request, user user.Info) {
c.lock.Lock()
defer c.lock.Unlock()
c.requests[req] = user
}
func (c *UserRequestContext) Remove(req *http.Request) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.requests, req)
}

View File

@ -22,15 +22,16 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/auth/authenticator" "github.com/GoogleCloudPlatform/kubernetes/pkg/auth/authenticator"
"github.com/GoogleCloudPlatform/kubernetes/pkg/auth/user" "github.com/GoogleCloudPlatform/kubernetes/pkg/auth/user"
) )
func TestAuthenticateRequest(t *testing.T) { func TestAuthenticateRequest(t *testing.T) {
success := make(chan struct{}) success := make(chan struct{})
context := NewUserRequestContext() contextMapper := api.NewRequestContextMapper()
auth := NewRequestAuthenticator( auth, err := NewRequestAuthenticator(
context, contextMapper,
authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) { authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) {
return &user.DefaultInfo{Name: "user"}, true, nil return &user.DefaultInfo{Name: "user"}, true, nil
}), }),
@ -38,8 +39,13 @@ func TestAuthenticateRequest(t *testing.T) {
t.Errorf("unexpected call to failed") t.Errorf("unexpected call to failed")
}), }),
http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
if user, ok := context.Get(req); user == nil || !ok { ctx, ok := contextMapper.Get(req)
t.Errorf("no user stored on context: %#v", context) if ctx == nil || !ok {
t.Errorf("no context stored on contextMapper: %#v", contextMapper)
}
user, ok := api.UserFrom(ctx)
if user == nil || !ok {
t.Errorf("no user stored in context: %#v", ctx)
} }
close(success) close(success)
}), }),
@ -48,16 +54,20 @@ func TestAuthenticateRequest(t *testing.T) {
auth.ServeHTTP(httptest.NewRecorder(), &http.Request{}) auth.ServeHTTP(httptest.NewRecorder(), &http.Request{})
<-success <-success
if len(context.requests) > 0 { empty, err := api.IsEmpty(contextMapper)
t.Errorf("context should have no stored requests: %v", context) if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !empty {
t.Fatalf("contextMapper should have no stored requests: %v", contextMapper)
} }
} }
func TestAuthenticateRequestFailed(t *testing.T) { func TestAuthenticateRequestFailed(t *testing.T) {
failed := make(chan struct{}) failed := make(chan struct{})
context := NewUserRequestContext() contextMapper := api.NewRequestContextMapper()
auth := NewRequestAuthenticator( auth, err := NewRequestAuthenticator(
context, contextMapper,
authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) { authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) {
return nil, false, nil return nil, false, nil
}), }),
@ -72,16 +82,20 @@ func TestAuthenticateRequestFailed(t *testing.T) {
auth.ServeHTTP(httptest.NewRecorder(), &http.Request{}) auth.ServeHTTP(httptest.NewRecorder(), &http.Request{})
<-failed <-failed
if len(context.requests) > 0 { empty, err := api.IsEmpty(contextMapper)
t.Errorf("context should have no stored requests: %v", context) if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !empty {
t.Fatalf("contextMapper should have no stored requests: %v", contextMapper)
} }
} }
func TestAuthenticateRequestError(t *testing.T) { func TestAuthenticateRequestError(t *testing.T) {
failed := make(chan struct{}) failed := make(chan struct{})
context := NewUserRequestContext() contextMapper := api.NewRequestContextMapper()
auth := NewRequestAuthenticator( auth, err := NewRequestAuthenticator(
context, contextMapper,
authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) { authenticator.RequestFunc(func(req *http.Request) (user.Info, bool, error) {
return nil, false, errors.New("failure") return nil, false, errors.New("failure")
}), }),
@ -96,7 +110,11 @@ func TestAuthenticateRequestError(t *testing.T) {
auth.ServeHTTP(httptest.NewRecorder(), &http.Request{}) auth.ServeHTTP(httptest.NewRecorder(), &http.Request{})
<-failed <-failed
if len(context.requests) > 0 { empty, err := api.IsEmpty(contextMapper)
t.Errorf("context should have no stored requests: %v", context) if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !empty {
t.Fatalf("contextMapper should have no stored requests: %v", contextMapper)
} }
} }

View File

@ -89,6 +89,9 @@ type Config struct {
AdmissionControl admission.Interface AdmissionControl admission.Interface
MasterServiceNamespace string MasterServiceNamespace string
// Map requests to contexts. Exported so downstream consumers can provider their own mappers
RequestContextMapper api.RequestContextMapper
// If specified, all web services will be registered into this container // If specified, all web services will be registered into this container
RestfulContainer *restful.Container RestfulContainer *restful.Container
@ -143,6 +146,7 @@ type Master struct {
admissionControl admission.Interface admissionControl admission.Interface
masterCount int masterCount int
v1beta3 bool v1beta3 bool
requestContextMapper api.RequestContextMapper
publicIP net.IP publicIP net.IP
publicReadOnlyPort int publicReadOnlyPort int
@ -225,6 +229,9 @@ func setDefaults(c *Config) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
} }
} }
if c.RequestContextMapper == nil {
c.RequestContextMapper = api.NewRequestContextMapper()
}
} }
// New returns a new instance of Master from the given config. // New returns a new instance of Master from the given config.
@ -292,6 +299,7 @@ func New(c *Config) *Master {
authorizer: c.Authorizer, authorizer: c.Authorizer,
admissionControl: c.AdmissionControl, admissionControl: c.AdmissionControl,
v1beta3: c.EnableV1Beta3, v1beta3: c.EnableV1Beta3,
requestContextMapper: c.RequestContextMapper,
cacheTimeout: c.CacheTimeout, cacheTimeout: c.CacheTimeout,
@ -367,8 +375,6 @@ func logStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter)
// init initializes master. // init initializes master.
func (m *Master) init(c *Config) { func (m *Master) init(c *Config) {
var userContexts = handlers.NewUserRequestContext()
var authenticator = c.Authenticator
nodeRESTStorage := minion.NewREST(m.minionRegistry) nodeRESTStorage := minion.NewREST(m.minionRegistry)
podCache := NewPodCache( podCache := NewPodCache(
@ -453,12 +459,16 @@ func (m *Master) init(c *Config) {
m.InsecureHandler = handler m.InsecureHandler = handler
attributeGetter := apiserver.NewRequestAttributeGetter(userContexts, latest.RESTMapper, "api") attributeGetter := apiserver.NewRequestAttributeGetter(m.requestContextMapper, latest.RESTMapper, "api")
handler = apiserver.WithAuthorizationCheck(handler, attributeGetter, m.authorizer) handler = apiserver.WithAuthorizationCheck(handler, attributeGetter, m.authorizer)
// Install Authenticator // Install Authenticator
if authenticator != nil { if c.Authenticator != nil {
handler = handlers.NewRequestAuthenticator(userContexts, authenticator, handlers.Unauthorized, handler) authenticatedHandler, err := handlers.NewRequestAuthenticator(m.requestContextMapper, c.Authenticator, handlers.Unauthorized, handler)
if err != nil {
glog.Fatalf("Could not initialize authenticator: %v", err)
}
handler = authenticatedHandler
} }
// Install root web services // Install root web services
@ -471,6 +481,19 @@ func (m *Master) init(c *Config) {
m.InstallSwaggerAPI() m.InstallSwaggerAPI()
} }
// After all wrapping is done, put a context filter around both handlers
if handler, err := api.NewRequestContextFilter(m.requestContextMapper, m.Handler); err != nil {
glog.Fatalf("Could not initialize request context filter: %v", err)
} else {
m.Handler = handler
}
if handler, err := api.NewRequestContextFilter(m.requestContextMapper, m.InsecureHandler); err != nil {
glog.Fatalf("Could not initialize request context filter: %v", err)
} else {
m.InsecureHandler = handler
}
// TODO: Attempt clean shutdown? // TODO: Attempt clean shutdown?
m.masterServices.Start() m.masterServices.Start()
} }
@ -530,25 +553,25 @@ func (m *Master) getServersToValidate(c *Config) map[string]apiserver.Server {
} }
// api_v1beta1 returns the resources and codec for API version v1beta1. // api_v1beta1 returns the resources and codec for API version v1beta1.
func (m *Master) api_v1beta1() (map[string]apiserver.RESTStorage, runtime.Codec, string, string, runtime.SelfLinker, admission.Interface, meta.RESTMapper) { func (m *Master) api_v1beta1() (map[string]apiserver.RESTStorage, runtime.Codec, string, string, runtime.SelfLinker, admission.Interface, api.RequestContextMapper, meta.RESTMapper) {
storage := make(map[string]apiserver.RESTStorage) storage := make(map[string]apiserver.RESTStorage)
for k, v := range m.storage { for k, v := range m.storage {
storage[k] = v storage[k] = v
} }
return storage, v1beta1.Codec, "api", "/api/v1beta1", latest.SelfLinker, m.admissionControl, latest.RESTMapper return storage, v1beta1.Codec, "api", "/api/v1beta1", latest.SelfLinker, m.admissionControl, m.requestContextMapper, latest.RESTMapper
} }
// api_v1beta2 returns the resources and codec for API version v1beta2. // api_v1beta2 returns the resources and codec for API version v1beta2.
func (m *Master) api_v1beta2() (map[string]apiserver.RESTStorage, runtime.Codec, string, string, runtime.SelfLinker, admission.Interface, meta.RESTMapper) { func (m *Master) api_v1beta2() (map[string]apiserver.RESTStorage, runtime.Codec, string, string, runtime.SelfLinker, admission.Interface, api.RequestContextMapper, meta.RESTMapper) {
storage := make(map[string]apiserver.RESTStorage) storage := make(map[string]apiserver.RESTStorage)
for k, v := range m.storage { for k, v := range m.storage {
storage[k] = v storage[k] = v
} }
return storage, v1beta2.Codec, "api", "/api/v1beta2", latest.SelfLinker, m.admissionControl, latest.RESTMapper return storage, v1beta2.Codec, "api", "/api/v1beta2", latest.SelfLinker, m.admissionControl, m.requestContextMapper, latest.RESTMapper
} }
// api_v1beta3 returns the resources and codec for API version v1beta3. // api_v1beta3 returns the resources and codec for API version v1beta3.
func (m *Master) api_v1beta3() (map[string]apiserver.RESTStorage, runtime.Codec, string, string, runtime.SelfLinker, admission.Interface, meta.RESTMapper) { func (m *Master) api_v1beta3() (map[string]apiserver.RESTStorage, runtime.Codec, string, string, runtime.SelfLinker, admission.Interface, api.RequestContextMapper, meta.RESTMapper) {
storage := make(map[string]apiserver.RESTStorage) storage := make(map[string]apiserver.RESTStorage)
for k, v := range m.storage { for k, v := range m.storage {
if k == "minions" { if k == "minions" {
@ -556,5 +579,5 @@ func (m *Master) api_v1beta3() (map[string]apiserver.RESTStorage, runtime.Codec,
} }
storage[strings.ToLower(k)] = v storage[strings.ToLower(k)] = v
} }
return storage, v1beta3.Codec, "api", "/api/v1beta3", latest.SelfLinker, m.admissionControl, latest.RESTMapper return storage, v1beta3.Codec, "api", "/api/v1beta3", latest.SelfLinker, m.admissionControl, m.requestContextMapper, latest.RESTMapper
} }