From 174e454874d8e5fa41e2a04643be611912c52715 Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Fri, 9 Sep 2016 09:40:31 -0400 Subject: [PATCH] Allow short-circuiting union auth on error --- .../auth/authenticator/request/union/union.go | 34 +++++++++++++------ .../request/union/unionauth_test.go | 21 ++++++++++++ 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/plugin/pkg/auth/authenticator/request/union/union.go b/plugin/pkg/auth/authenticator/request/union/union.go index 5c34b4d5fcd..dc2640a613c 100644 --- a/plugin/pkg/auth/authenticator/request/union/union.go +++ b/plugin/pkg/auth/authenticator/request/union/union.go @@ -25,26 +25,40 @@ import ( ) // unionAuthRequestHandler authenticates requests using a chain of authenticator.Requests -type unionAuthRequestHandler []authenticator.Request - -// New returns a request authenticator that validates credentials using a chain of authenticator.Request objects -func New(authRequestHandlers ...authenticator.Request) authenticator.Request { - return unionAuthRequestHandler(authRequestHandlers) +type unionAuthRequestHandler struct { + // Handlers is a chain of request authenticators to delegate to + Handlers []authenticator.Request + // FailOnError determines whether an error returns short-circuits the chain + FailOnError bool } -// AuthenticateRequest authenticates the request using a chain of authenticator.Request objects. The first -// success returns that identity. Errors are only returned if no matches are found. -func (authHandler unionAuthRequestHandler) AuthenticateRequest(req *http.Request) (user.Info, bool, error) { +// New returns a request authenticator that validates credentials using a chain of authenticator.Request objects. +// The entire chain is tried until one succeeds. If all fail, an aggregate error is returned. +func New(authRequestHandlers ...authenticator.Request) authenticator.Request { + return &unionAuthRequestHandler{Handlers: authRequestHandlers, FailOnError: false} +} + +// NewFailOnError returns a request authenticator that validates credentials using a chain of authenticator.Request objects. +// The first error short-circuits the chain. +func NewFailOnError(authRequestHandlers ...authenticator.Request) authenticator.Request { + return &unionAuthRequestHandler{Handlers: authRequestHandlers, FailOnError: true} +} + +// AuthenticateRequest authenticates the request using a chain of authenticator.Request objects. +func (authHandler *unionAuthRequestHandler) AuthenticateRequest(req *http.Request) (user.Info, bool, error) { var errlist []error - for _, currAuthRequestHandler := range authHandler { + for _, currAuthRequestHandler := range authHandler.Handlers { info, ok, err := currAuthRequestHandler.AuthenticateRequest(req) if err != nil { + if authHandler.FailOnError { + return info, ok, err + } errlist = append(errlist, err) continue } if ok { - return info, true, nil + return info, ok, err } } diff --git a/plugin/pkg/auth/authenticator/request/union/unionauth_test.go b/plugin/pkg/auth/authenticator/request/union/unionauth_test.go index 62fa0c0da59..8792339501f 100644 --- a/plugin/pkg/auth/authenticator/request/union/unionauth_test.go +++ b/plugin/pkg/auth/authenticator/request/union/unionauth_test.go @@ -143,3 +143,24 @@ func TestAuthenticateRequestAdditiveErrors(t *testing.T) { t.Errorf("Unexpectedly authenticated: %v", isAuthenticated) } } + +func TestAuthenticateRequestFailEarly(t *testing.T) { + handler1 := &mockAuthRequestHandler{err: errors.New("first")} + handler2 := &mockAuthRequestHandler{err: errors.New("second")} + authRequestHandler := NewFailOnError(handler1, handler2) + req, _ := http.NewRequest("GET", "http://example.org", nil) + + _, isAuthenticated, err := authRequestHandler.AuthenticateRequest(req) + if err == nil { + t.Errorf("Expected an error") + } + if !strings.Contains(err.Error(), "first") { + t.Errorf("Expected error containing %v, got %v", "first", err) + } + if strings.Contains(err.Error(), "second") { + t.Errorf("Did not expect second error, got %v", err) + } + if isAuthenticated { + t.Errorf("Unexpectedly authenticated: %v", isAuthenticated) + } +}