From 212190e25e18600bbca2eb5c77aa3fe5bcc55af1 Mon Sep 17 00:00:00 2001 From: Alvaro Aleman Date: Mon, 24 Feb 2020 15:15:04 +0100 Subject: [PATCH] Utilerrors.Aggregate: Allow using with errors.Is() --- .../apimachinery/pkg/util/errors/errors.go | 32 ++++-- .../pkg/util/errors/errors_test.go | 98 ++++++++++++++++++ .../client-go/tools/clientcmd/validation.go | 32 +++++- .../tools/clientcmd/validation_test.go | 99 +++++++++++++++++++ 4 files changed, 254 insertions(+), 7 deletions(-) diff --git a/staging/src/k8s.io/apimachinery/pkg/util/errors/errors.go b/staging/src/k8s.io/apimachinery/pkg/util/errors/errors.go index 62a73f34ebe..5bafc218e2f 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/errors/errors.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/errors/errors.go @@ -28,9 +28,14 @@ type MessageCountMap map[string]int // Aggregate represents an object that contains multiple errors, but does not // necessarily have singular semantic meaning. +// The aggregate can be used with `errors.Is()` to check for the occurrence of +// a specific error type. +// Errors.As() is not supported, because the caller presumably cares about a +// specific error of potentially multiple that match the given type. type Aggregate interface { error Errors() []error + Is(error) bool } // NewAggregate converts a slice of errors into an Aggregate interface, which @@ -71,16 +76,17 @@ func (agg aggregate) Error() string { } seenerrs := sets.NewString() result := "" - agg.visit(func(err error) { + agg.visit(func(err error) bool { msg := err.Error() if seenerrs.Has(msg) { - return + return false } seenerrs.Insert(msg) if len(seenerrs) > 1 { result += ", " } result += msg + return false }) if len(seenerrs) == 1 { return result @@ -88,19 +94,33 @@ func (agg aggregate) Error() string { return "[" + result + "]" } -func (agg aggregate) visit(f func(err error)) { +func (agg aggregate) Is(target error) bool { + return agg.visit(func(err error) bool { + return errors.Is(err, target) + }) +} + +func (agg aggregate) visit(f func(err error) bool) bool { for _, err := range agg { switch err := err.(type) { case aggregate: - err.visit(f) + if match := err.visit(f); match { + return match + } case Aggregate: for _, nestedErr := range err.Errors() { - f(nestedErr) + if match := f(nestedErr); match { + return match + } } default: - f(err) + if match := f(err); match { + return match + } } } + + return false } // Errors is part of the Aggregate interface. diff --git a/staging/src/k8s.io/apimachinery/pkg/util/errors/errors_test.go b/staging/src/k8s.io/apimachinery/pkg/util/errors/errors_test.go index d70a4d51a0b..55d253bb20d 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/errors/errors_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/errors/errors_test.go @@ -17,6 +17,7 @@ limitations under the License. package errors import ( + "errors" "fmt" "reflect" "sort" @@ -430,3 +431,100 @@ func TestAggregateGoroutines(t *testing.T) { } } } + +type alwaysMatchingError struct{} + +func (_ alwaysMatchingError) Error() string { + return "error" +} + +func (_ alwaysMatchingError) Is(_ error) bool { + return true +} + +type someError struct{ msg string } + +func (se someError) Error() string { + if se.msg != "" { + return se.msg + } + return "err" +} + +func TestAggregateWithErrorsIs(t *testing.T) { + testCases := []struct { + name string + err error + matchAgainst error + expectMatch bool + }{ + { + name: "no match", + err: aggregate{errors.New("my-error"), errors.New("my-other-error")}, + matchAgainst: fmt.Errorf("no entry %s", "here"), + }, + { + name: "match via .Is()", + err: aggregate{errors.New("forbidden"), alwaysMatchingError{}}, + matchAgainst: errors.New("unauthorized"), + expectMatch: true, + }, + { + name: "match via equality", + err: aggregate{errors.New("err"), someError{}}, + matchAgainst: someError{}, + expectMatch: true, + }, + { + name: "match via nested aggregate", + err: aggregate{errors.New("closed today"), aggregate{aggregate{someError{}}}}, + matchAgainst: someError{}, + expectMatch: true, + }, + { + name: "match via wrapped aggregate", + err: fmt.Errorf("wrap: %w", aggregate{errors.New("err"), someError{}}), + matchAgainst: someError{}, + expectMatch: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := errors.Is(tc.err, tc.matchAgainst) + if result != tc.expectMatch { + t.Errorf("expected match: %t, got match: %t", tc.expectMatch, result) + } + }) + } +} + +type accessTrackingError struct { + wasAccessed bool +} + +func (accessTrackingError) Error() string { + return "err" +} + +func (ate *accessTrackingError) Is(_ error) bool { + ate.wasAccessed = true + return true +} + +var _ error = &accessTrackingError{} + +func TestErrConfigurationInvalidWithErrorsIsShortCircuitsOnFirstMatch(t *testing.T) { + errC := aggregate{&accessTrackingError{}, &accessTrackingError{}} + _ = errors.Is(errC, &accessTrackingError{}) + + var numAccessed int + for _, err := range errC { + if ate := err.(*accessTrackingError); ate.wasAccessed { + numAccessed++ + } + } + if numAccessed != 1 { + t.Errorf("expected exactly one error to get accessed, got %d", numAccessed) + } +} diff --git a/staging/src/k8s.io/client-go/tools/clientcmd/validation.go b/staging/src/k8s.io/client-go/tools/clientcmd/validation.go index 2f927072bde..06d8f102a9a 100644 --- a/staging/src/k8s.io/client-go/tools/clientcmd/validation.go +++ b/staging/src/k8s.io/client-go/tools/clientcmd/validation.go @@ -86,11 +86,41 @@ func (e errConfigurationInvalid) Error() string { return fmt.Sprintf("invalid configuration: %v", utilerrors.NewAggregate(e).Error()) } -// Errors implements the AggregateError interface +// Errors implements the utilerrors.Aggregate interface func (e errConfigurationInvalid) Errors() []error { return e } +// Is implements the utilerrors.Aggregate interface +func (e errConfigurationInvalid) Is(target error) bool { + return e.visit(func(err error) bool { + return errors.Is(err, target) + }) +} + +func (e errConfigurationInvalid) visit(f func(err error) bool) bool { + for _, err := range e { + switch err := err.(type) { + case errConfigurationInvalid: + if match := err.visit(f); match { + return match + } + case utilerrors.Aggregate: + for _, nestedErr := range err.Errors() { + if match := f(nestedErr); match { + return match + } + } + default: + if match := f(err); match { + return match + } + } + } + + return false +} + // IsConfigurationInvalid returns true if the provided error indicates the configuration is invalid. func IsConfigurationInvalid(err error) bool { switch err.(type) { diff --git a/staging/src/k8s.io/client-go/tools/clientcmd/validation_test.go b/staging/src/k8s.io/client-go/tools/clientcmd/validation_test.go index 1680eeb1aa9..caf67054c16 100644 --- a/staging/src/k8s.io/client-go/tools/clientcmd/validation_test.go +++ b/staging/src/k8s.io/client-go/tools/clientcmd/validation_test.go @@ -17,6 +17,8 @@ limitations under the License. package clientcmd import ( + "errors" + "fmt" "io/ioutil" "os" "strings" @@ -569,3 +571,100 @@ func (c configValidationTest) testAuthInfo(authInfoName string, t *testing.T) { } } } + +type alwaysMatchingError struct{} + +func (_ alwaysMatchingError) Error() string { + return "error" +} + +func (_ alwaysMatchingError) Is(_ error) bool { + return true +} + +type someError struct{ msg string } + +func (se someError) Error() string { + if se.msg != "" { + return se.msg + } + return "err" +} + +func TestErrConfigurationInvalidWithErrorsIs(t *testing.T) { + testCases := []struct { + name string + err error + matchAgainst error + expectMatch bool + }{ + { + name: "no match", + err: errConfigurationInvalid{errors.New("my-error"), errors.New("my-other-error")}, + matchAgainst: fmt.Errorf("no entry %s", "here"), + }, + { + name: "match via .Is()", + err: errConfigurationInvalid{errors.New("forbidden"), alwaysMatchingError{}}, + matchAgainst: errors.New("unauthorized"), + expectMatch: true, + }, + { + name: "match via equality", + err: errConfigurationInvalid{errors.New("err"), someError{}}, + matchAgainst: someError{}, + expectMatch: true, + }, + { + name: "match via nested aggregate", + err: errConfigurationInvalid{errors.New("closed today"), errConfigurationInvalid{errConfigurationInvalid{someError{}}}}, + matchAgainst: someError{}, + expectMatch: true, + }, + { + name: "match via wrapped aggregate", + err: fmt.Errorf("wrap: %w", errConfigurationInvalid{errors.New("err"), someError{}}), + matchAgainst: someError{}, + expectMatch: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := errors.Is(tc.err, tc.matchAgainst) + if result != tc.expectMatch { + t.Errorf("expected match: %t, got match: %t", tc.expectMatch, result) + } + }) + } +} + +type accessTrackingError struct { + wasAccessed bool +} + +func (accessTrackingError) Error() string { + return "err" +} + +func (ate *accessTrackingError) Is(_ error) bool { + ate.wasAccessed = true + return true +} + +var _ error = &accessTrackingError{} + +func TestErrConfigurationInvalidWithErrorsIsShortCircuitsOnFirstMatch(t *testing.T) { + errC := errConfigurationInvalid{&accessTrackingError{}, &accessTrackingError{}} + _ = errors.Is(errC, &accessTrackingError{}) + + var numAccessed int + for _, err := range errC { + if ate := err.(*accessTrackingError); ate.wasAccessed { + numAccessed++ + } + } + if numAccessed != 1 { + t.Errorf("expected exactly one error to get accessed, got %d", numAccessed) + } +}