From 6b7675c693a319f2b3ef591d68bbe587483566b9 Mon Sep 17 00:00:00 2001 From: Alvaro Aleman Date: Mon, 24 Feb 2020 15:15:04 +0100 Subject: [PATCH] Utilerrors.Aggregate: Allow using with errors.Is() Kubernetes-commit: 212190e25e18600bbca2eb5c77aa3fe5bcc55af1 --- tools/clientcmd/validation.go | 32 +++++++++- tools/clientcmd/validation_test.go | 99 ++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 1 deletion(-) diff --git a/tools/clientcmd/validation.go b/tools/clientcmd/validation.go index 02dd1130..afe6f80b 100644 --- a/tools/clientcmd/validation.go +++ b/tools/clientcmd/validation.go @@ -86,11 +86,41 @@ func (e errConfigurationInvalid) Error() string { return fmt.Sprintf("invalid configuration: %v", utilerrors.NewAggregate(e).Error()) } -// Errors implements the AggregateError interface +// Errors implements the utilerrors.Aggregate interface func (e errConfigurationInvalid) Errors() []error { return e } +// Is implements the utilerrors.Aggregate interface +func (e errConfigurationInvalid) Is(target error) bool { + return e.visit(func(err error) bool { + return errors.Is(err, target) + }) +} + +func (e errConfigurationInvalid) visit(f func(err error) bool) bool { + for _, err := range e { + switch err := err.(type) { + case errConfigurationInvalid: + if match := err.visit(f); match { + return match + } + case utilerrors.Aggregate: + for _, nestedErr := range err.Errors() { + if match := f(nestedErr); match { + return match + } + } + default: + if match := f(err); match { + return match + } + } + } + + return false +} + // IsConfigurationInvalid returns true if the provided error indicates the configuration is invalid. func IsConfigurationInvalid(err error) bool { switch err.(type) { diff --git a/tools/clientcmd/validation_test.go b/tools/clientcmd/validation_test.go index 1680eeb1..caf67054 100644 --- a/tools/clientcmd/validation_test.go +++ b/tools/clientcmd/validation_test.go @@ -17,6 +17,8 @@ limitations under the License. package clientcmd import ( + "errors" + "fmt" "io/ioutil" "os" "strings" @@ -569,3 +571,100 @@ func (c configValidationTest) testAuthInfo(authInfoName string, t *testing.T) { } } } + +type alwaysMatchingError struct{} + +func (_ alwaysMatchingError) Error() string { + return "error" +} + +func (_ alwaysMatchingError) Is(_ error) bool { + return true +} + +type someError struct{ msg string } + +func (se someError) Error() string { + if se.msg != "" { + return se.msg + } + return "err" +} + +func TestErrConfigurationInvalidWithErrorsIs(t *testing.T) { + testCases := []struct { + name string + err error + matchAgainst error + expectMatch bool + }{ + { + name: "no match", + err: errConfigurationInvalid{errors.New("my-error"), errors.New("my-other-error")}, + matchAgainst: fmt.Errorf("no entry %s", "here"), + }, + { + name: "match via .Is()", + err: errConfigurationInvalid{errors.New("forbidden"), alwaysMatchingError{}}, + matchAgainst: errors.New("unauthorized"), + expectMatch: true, + }, + { + name: "match via equality", + err: errConfigurationInvalid{errors.New("err"), someError{}}, + matchAgainst: someError{}, + expectMatch: true, + }, + { + name: "match via nested aggregate", + err: errConfigurationInvalid{errors.New("closed today"), errConfigurationInvalid{errConfigurationInvalid{someError{}}}}, + matchAgainst: someError{}, + expectMatch: true, + }, + { + name: "match via wrapped aggregate", + err: fmt.Errorf("wrap: %w", errConfigurationInvalid{errors.New("err"), someError{}}), + matchAgainst: someError{}, + expectMatch: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := errors.Is(tc.err, tc.matchAgainst) + if result != tc.expectMatch { + t.Errorf("expected match: %t, got match: %t", tc.expectMatch, result) + } + }) + } +} + +type accessTrackingError struct { + wasAccessed bool +} + +func (accessTrackingError) Error() string { + return "err" +} + +func (ate *accessTrackingError) Is(_ error) bool { + ate.wasAccessed = true + return true +} + +var _ error = &accessTrackingError{} + +func TestErrConfigurationInvalidWithErrorsIsShortCircuitsOnFirstMatch(t *testing.T) { + errC := errConfigurationInvalid{&accessTrackingError{}, &accessTrackingError{}} + _ = errors.Is(errC, &accessTrackingError{}) + + var numAccessed int + for _, err := range errC { + if ate := err.(*accessTrackingError); ate.wasAccessed { + numAccessed++ + } + } + if numAccessed != 1 { + t.Errorf("expected exactly one error to get accessed, got %d", numAccessed) + } +}