Merge pull request #88465 from alvaroaleman/utilerrors-implement-errors-is

Utilerrors.Aggregate: Allow using with errors.Is()
This commit is contained in:
Kubernetes Prow Robot 2020-03-05 20:03:53 -08:00 committed by GitHub
commit c812375ed6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 254 additions and 7 deletions

View File

@ -28,9 +28,14 @@ type MessageCountMap map[string]int
// Aggregate represents an object that contains multiple errors, but does not // Aggregate represents an object that contains multiple errors, but does not
// necessarily have singular semantic meaning. // necessarily have singular semantic meaning.
// The aggregate can be used with `errors.Is()` to check for the occurrence of
// a specific error type.
// Errors.As() is not supported, because the caller presumably cares about a
// specific error of potentially multiple that match the given type.
type Aggregate interface { type Aggregate interface {
error error
Errors() []error Errors() []error
Is(error) bool
} }
// NewAggregate converts a slice of errors into an Aggregate interface, which // NewAggregate converts a slice of errors into an Aggregate interface, which
@ -71,16 +76,17 @@ func (agg aggregate) Error() string {
} }
seenerrs := sets.NewString() seenerrs := sets.NewString()
result := "" result := ""
agg.visit(func(err error) { agg.visit(func(err error) bool {
msg := err.Error() msg := err.Error()
if seenerrs.Has(msg) { if seenerrs.Has(msg) {
return return false
} }
seenerrs.Insert(msg) seenerrs.Insert(msg)
if len(seenerrs) > 1 { if len(seenerrs) > 1 {
result += ", " result += ", "
} }
result += msg result += msg
return false
}) })
if len(seenerrs) == 1 { if len(seenerrs) == 1 {
return result return result
@ -88,19 +94,33 @@ func (agg aggregate) Error() string {
return "[" + result + "]" return "[" + result + "]"
} }
func (agg aggregate) visit(f func(err error)) { func (agg aggregate) Is(target error) bool {
return agg.visit(func(err error) bool {
return errors.Is(err, target)
})
}
func (agg aggregate) visit(f func(err error) bool) bool {
for _, err := range agg { for _, err := range agg {
switch err := err.(type) { switch err := err.(type) {
case aggregate: case aggregate:
err.visit(f) if match := err.visit(f); match {
return match
}
case Aggregate: case Aggregate:
for _, nestedErr := range err.Errors() { for _, nestedErr := range err.Errors() {
f(nestedErr) if match := f(nestedErr); match {
return match
}
} }
default: default:
f(err) if match := f(err); match {
return match
}
} }
} }
return false
} }
// Errors is part of the Aggregate interface. // Errors is part of the Aggregate interface.

View File

@ -17,6 +17,7 @@ limitations under the License.
package errors package errors
import ( import (
"errors"
"fmt" "fmt"
"reflect" "reflect"
"sort" "sort"
@ -430,3 +431,100 @@ func TestAggregateGoroutines(t *testing.T) {
} }
} }
} }
type alwaysMatchingError struct{}
func (_ alwaysMatchingError) Error() string {
return "error"
}
func (_ alwaysMatchingError) Is(_ error) bool {
return true
}
type someError struct{ msg string }
func (se someError) Error() string {
if se.msg != "" {
return se.msg
}
return "err"
}
func TestAggregateWithErrorsIs(t *testing.T) {
testCases := []struct {
name string
err error
matchAgainst error
expectMatch bool
}{
{
name: "no match",
err: aggregate{errors.New("my-error"), errors.New("my-other-error")},
matchAgainst: fmt.Errorf("no entry %s", "here"),
},
{
name: "match via .Is()",
err: aggregate{errors.New("forbidden"), alwaysMatchingError{}},
matchAgainst: errors.New("unauthorized"),
expectMatch: true,
},
{
name: "match via equality",
err: aggregate{errors.New("err"), someError{}},
matchAgainst: someError{},
expectMatch: true,
},
{
name: "match via nested aggregate",
err: aggregate{errors.New("closed today"), aggregate{aggregate{someError{}}}},
matchAgainst: someError{},
expectMatch: true,
},
{
name: "match via wrapped aggregate",
err: fmt.Errorf("wrap: %w", aggregate{errors.New("err"), someError{}}),
matchAgainst: someError{},
expectMatch: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := errors.Is(tc.err, tc.matchAgainst)
if result != tc.expectMatch {
t.Errorf("expected match: %t, got match: %t", tc.expectMatch, result)
}
})
}
}
type accessTrackingError struct {
wasAccessed bool
}
func (accessTrackingError) Error() string {
return "err"
}
func (ate *accessTrackingError) Is(_ error) bool {
ate.wasAccessed = true
return true
}
var _ error = &accessTrackingError{}
func TestErrConfigurationInvalidWithErrorsIsShortCircuitsOnFirstMatch(t *testing.T) {
errC := aggregate{&accessTrackingError{}, &accessTrackingError{}}
_ = errors.Is(errC, &accessTrackingError{})
var numAccessed int
for _, err := range errC {
if ate := err.(*accessTrackingError); ate.wasAccessed {
numAccessed++
}
}
if numAccessed != 1 {
t.Errorf("expected exactly one error to get accessed, got %d", numAccessed)
}
}

View File

@ -86,11 +86,41 @@ func (e errConfigurationInvalid) Error() string {
return fmt.Sprintf("invalid configuration: %v", utilerrors.NewAggregate(e).Error()) return fmt.Sprintf("invalid configuration: %v", utilerrors.NewAggregate(e).Error())
} }
// Errors implements the AggregateError interface // Errors implements the utilerrors.Aggregate interface
func (e errConfigurationInvalid) Errors() []error { func (e errConfigurationInvalid) Errors() []error {
return e return e
} }
// Is implements the utilerrors.Aggregate interface
func (e errConfigurationInvalid) Is(target error) bool {
return e.visit(func(err error) bool {
return errors.Is(err, target)
})
}
func (e errConfigurationInvalid) visit(f func(err error) bool) bool {
for _, err := range e {
switch err := err.(type) {
case errConfigurationInvalid:
if match := err.visit(f); match {
return match
}
case utilerrors.Aggregate:
for _, nestedErr := range err.Errors() {
if match := f(nestedErr); match {
return match
}
}
default:
if match := f(err); match {
return match
}
}
}
return false
}
// IsConfigurationInvalid returns true if the provided error indicates the configuration is invalid. // IsConfigurationInvalid returns true if the provided error indicates the configuration is invalid.
func IsConfigurationInvalid(err error) bool { func IsConfigurationInvalid(err error) bool {
switch err.(type) { switch err.(type) {

View File

@ -17,6 +17,8 @@ limitations under the License.
package clientcmd package clientcmd
import ( import (
"errors"
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"strings" "strings"
@ -569,3 +571,100 @@ func (c configValidationTest) testAuthInfo(authInfoName string, t *testing.T) {
} }
} }
} }
type alwaysMatchingError struct{}
func (_ alwaysMatchingError) Error() string {
return "error"
}
func (_ alwaysMatchingError) Is(_ error) bool {
return true
}
type someError struct{ msg string }
func (se someError) Error() string {
if se.msg != "" {
return se.msg
}
return "err"
}
func TestErrConfigurationInvalidWithErrorsIs(t *testing.T) {
testCases := []struct {
name string
err error
matchAgainst error
expectMatch bool
}{
{
name: "no match",
err: errConfigurationInvalid{errors.New("my-error"), errors.New("my-other-error")},
matchAgainst: fmt.Errorf("no entry %s", "here"),
},
{
name: "match via .Is()",
err: errConfigurationInvalid{errors.New("forbidden"), alwaysMatchingError{}},
matchAgainst: errors.New("unauthorized"),
expectMatch: true,
},
{
name: "match via equality",
err: errConfigurationInvalid{errors.New("err"), someError{}},
matchAgainst: someError{},
expectMatch: true,
},
{
name: "match via nested aggregate",
err: errConfigurationInvalid{errors.New("closed today"), errConfigurationInvalid{errConfigurationInvalid{someError{}}}},
matchAgainst: someError{},
expectMatch: true,
},
{
name: "match via wrapped aggregate",
err: fmt.Errorf("wrap: %w", errConfigurationInvalid{errors.New("err"), someError{}}),
matchAgainst: someError{},
expectMatch: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := errors.Is(tc.err, tc.matchAgainst)
if result != tc.expectMatch {
t.Errorf("expected match: %t, got match: %t", tc.expectMatch, result)
}
})
}
}
type accessTrackingError struct {
wasAccessed bool
}
func (accessTrackingError) Error() string {
return "err"
}
func (ate *accessTrackingError) Is(_ error) bool {
ate.wasAccessed = true
return true
}
var _ error = &accessTrackingError{}
func TestErrConfigurationInvalidWithErrorsIsShortCircuitsOnFirstMatch(t *testing.T) {
errC := errConfigurationInvalid{&accessTrackingError{}, &accessTrackingError{}}
_ = errors.Is(errC, &accessTrackingError{})
var numAccessed int
for _, err := range errC {
if ate := err.(*accessTrackingError); ate.wasAccessed {
numAccessed++
}
}
if numAccessed != 1 {
t.Errorf("expected exactly one error to get accessed, got %d", numAccessed)
}
}