Allow short-circuiting union auth on error

This commit is contained in:
Jordan Liggitt 2016-09-09 09:40:31 -04:00
parent 0f3baaad50
commit 174e454874
No known key found for this signature in database
GPG Key ID: 24E7ADF9A3B42012
2 changed files with 45 additions and 10 deletions

View File

@ -25,26 +25,40 @@ import (
) )
// unionAuthRequestHandler authenticates requests using a chain of authenticator.Requests // unionAuthRequestHandler authenticates requests using a chain of authenticator.Requests
type unionAuthRequestHandler []authenticator.Request type unionAuthRequestHandler struct {
// Handlers is a chain of request authenticators to delegate to
// New returns a request authenticator that validates credentials using a chain of authenticator.Request objects Handlers []authenticator.Request
func New(authRequestHandlers ...authenticator.Request) authenticator.Request { // FailOnError determines whether an error returns short-circuits the chain
return unionAuthRequestHandler(authRequestHandlers) FailOnError bool
} }
// AuthenticateRequest authenticates the request using a chain of authenticator.Request objects. The first // New returns a request authenticator that validates credentials using a chain of authenticator.Request objects.
// success returns that identity. Errors are only returned if no matches are found. // The entire chain is tried until one succeeds. If all fail, an aggregate error is returned.
func (authHandler unionAuthRequestHandler) AuthenticateRequest(req *http.Request) (user.Info, bool, error) { func New(authRequestHandlers ...authenticator.Request) authenticator.Request {
return &unionAuthRequestHandler{Handlers: authRequestHandlers, FailOnError: false}
}
// NewFailOnError returns a request authenticator that validates credentials using a chain of authenticator.Request objects.
// The first error short-circuits the chain.
func NewFailOnError(authRequestHandlers ...authenticator.Request) authenticator.Request {
return &unionAuthRequestHandler{Handlers: authRequestHandlers, FailOnError: true}
}
// AuthenticateRequest authenticates the request using a chain of authenticator.Request objects.
func (authHandler *unionAuthRequestHandler) AuthenticateRequest(req *http.Request) (user.Info, bool, error) {
var errlist []error var errlist []error
for _, currAuthRequestHandler := range authHandler { for _, currAuthRequestHandler := range authHandler.Handlers {
info, ok, err := currAuthRequestHandler.AuthenticateRequest(req) info, ok, err := currAuthRequestHandler.AuthenticateRequest(req)
if err != nil { if err != nil {
if authHandler.FailOnError {
return info, ok, err
}
errlist = append(errlist, err) errlist = append(errlist, err)
continue continue
} }
if ok { if ok {
return info, true, nil return info, ok, err
} }
} }

View File

@ -143,3 +143,24 @@ func TestAuthenticateRequestAdditiveErrors(t *testing.T) {
t.Errorf("Unexpectedly authenticated: %v", isAuthenticated) t.Errorf("Unexpectedly authenticated: %v", isAuthenticated)
} }
} }
func TestAuthenticateRequestFailEarly(t *testing.T) {
handler1 := &mockAuthRequestHandler{err: errors.New("first")}
handler2 := &mockAuthRequestHandler{err: errors.New("second")}
authRequestHandler := NewFailOnError(handler1, handler2)
req, _ := http.NewRequest("GET", "http://example.org", nil)
_, isAuthenticated, err := authRequestHandler.AuthenticateRequest(req)
if err == nil {
t.Errorf("Expected an error")
}
if !strings.Contains(err.Error(), "first") {
t.Errorf("Expected error containing %v, got %v", "first", err)
}
if strings.Contains(err.Error(), "second") {
t.Errorf("Did not expect second error, got %v", err)
}
if isAuthenticated {
t.Errorf("Unexpectedly authenticated: %v", isAuthenticated)
}
}