mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-23 11:50:44 +00:00
Merge pull request #88465 from alvaroaleman/utilerrors-implement-errors-is
Utilerrors.Aggregate: Allow using with errors.Is()
This commit is contained in:
commit
c812375ed6
@ -28,9 +28,14 @@ type MessageCountMap map[string]int
|
|||||||
|
|
||||||
// Aggregate represents an object that contains multiple errors, but does not
|
// Aggregate represents an object that contains multiple errors, but does not
|
||||||
// necessarily have singular semantic meaning.
|
// necessarily have singular semantic meaning.
|
||||||
|
// The aggregate can be used with `errors.Is()` to check for the occurrence of
|
||||||
|
// a specific error type.
|
||||||
|
// Errors.As() is not supported, because the caller presumably cares about a
|
||||||
|
// specific error of potentially multiple that match the given type.
|
||||||
type Aggregate interface {
|
type Aggregate interface {
|
||||||
error
|
error
|
||||||
Errors() []error
|
Errors() []error
|
||||||
|
Is(error) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAggregate converts a slice of errors into an Aggregate interface, which
|
// NewAggregate converts a slice of errors into an Aggregate interface, which
|
||||||
@ -71,16 +76,17 @@ func (agg aggregate) Error() string {
|
|||||||
}
|
}
|
||||||
seenerrs := sets.NewString()
|
seenerrs := sets.NewString()
|
||||||
result := ""
|
result := ""
|
||||||
agg.visit(func(err error) {
|
agg.visit(func(err error) bool {
|
||||||
msg := err.Error()
|
msg := err.Error()
|
||||||
if seenerrs.Has(msg) {
|
if seenerrs.Has(msg) {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
seenerrs.Insert(msg)
|
seenerrs.Insert(msg)
|
||||||
if len(seenerrs) > 1 {
|
if len(seenerrs) > 1 {
|
||||||
result += ", "
|
result += ", "
|
||||||
}
|
}
|
||||||
result += msg
|
result += msg
|
||||||
|
return false
|
||||||
})
|
})
|
||||||
if len(seenerrs) == 1 {
|
if len(seenerrs) == 1 {
|
||||||
return result
|
return result
|
||||||
@ -88,19 +94,33 @@ func (agg aggregate) Error() string {
|
|||||||
return "[" + result + "]"
|
return "[" + result + "]"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (agg aggregate) visit(f func(err error)) {
|
func (agg aggregate) Is(target error) bool {
|
||||||
|
return agg.visit(func(err error) bool {
|
||||||
|
return errors.Is(err, target)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (agg aggregate) visit(f func(err error) bool) bool {
|
||||||
for _, err := range agg {
|
for _, err := range agg {
|
||||||
switch err := err.(type) {
|
switch err := err.(type) {
|
||||||
case aggregate:
|
case aggregate:
|
||||||
err.visit(f)
|
if match := err.visit(f); match {
|
||||||
|
return match
|
||||||
|
}
|
||||||
case Aggregate:
|
case Aggregate:
|
||||||
for _, nestedErr := range err.Errors() {
|
for _, nestedErr := range err.Errors() {
|
||||||
f(nestedErr)
|
if match := f(nestedErr); match {
|
||||||
|
return match
|
||||||
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
f(err)
|
if match := f(err); match {
|
||||||
|
return match
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Errors is part of the Aggregate interface.
|
// Errors is part of the Aggregate interface.
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
package errors
|
package errors
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
@ -430,3 +431,100 @@ func TestAggregateGoroutines(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type alwaysMatchingError struct{}
|
||||||
|
|
||||||
|
func (_ alwaysMatchingError) Error() string {
|
||||||
|
return "error"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_ alwaysMatchingError) Is(_ error) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type someError struct{ msg string }
|
||||||
|
|
||||||
|
func (se someError) Error() string {
|
||||||
|
if se.msg != "" {
|
||||||
|
return se.msg
|
||||||
|
}
|
||||||
|
return "err"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAggregateWithErrorsIs(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
matchAgainst error
|
||||||
|
expectMatch bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no match",
|
||||||
|
err: aggregate{errors.New("my-error"), errors.New("my-other-error")},
|
||||||
|
matchAgainst: fmt.Errorf("no entry %s", "here"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match via .Is()",
|
||||||
|
err: aggregate{errors.New("forbidden"), alwaysMatchingError{}},
|
||||||
|
matchAgainst: errors.New("unauthorized"),
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match via equality",
|
||||||
|
err: aggregate{errors.New("err"), someError{}},
|
||||||
|
matchAgainst: someError{},
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match via nested aggregate",
|
||||||
|
err: aggregate{errors.New("closed today"), aggregate{aggregate{someError{}}}},
|
||||||
|
matchAgainst: someError{},
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match via wrapped aggregate",
|
||||||
|
err: fmt.Errorf("wrap: %w", aggregate{errors.New("err"), someError{}}),
|
||||||
|
matchAgainst: someError{},
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := errors.Is(tc.err, tc.matchAgainst)
|
||||||
|
if result != tc.expectMatch {
|
||||||
|
t.Errorf("expected match: %t, got match: %t", tc.expectMatch, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type accessTrackingError struct {
|
||||||
|
wasAccessed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (accessTrackingError) Error() string {
|
||||||
|
return "err"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ate *accessTrackingError) Is(_ error) bool {
|
||||||
|
ate.wasAccessed = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ error = &accessTrackingError{}
|
||||||
|
|
||||||
|
func TestErrConfigurationInvalidWithErrorsIsShortCircuitsOnFirstMatch(t *testing.T) {
|
||||||
|
errC := aggregate{&accessTrackingError{}, &accessTrackingError{}}
|
||||||
|
_ = errors.Is(errC, &accessTrackingError{})
|
||||||
|
|
||||||
|
var numAccessed int
|
||||||
|
for _, err := range errC {
|
||||||
|
if ate := err.(*accessTrackingError); ate.wasAccessed {
|
||||||
|
numAccessed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if numAccessed != 1 {
|
||||||
|
t.Errorf("expected exactly one error to get accessed, got %d", numAccessed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -86,11 +86,41 @@ func (e errConfigurationInvalid) Error() string {
|
|||||||
return fmt.Sprintf("invalid configuration: %v", utilerrors.NewAggregate(e).Error())
|
return fmt.Sprintf("invalid configuration: %v", utilerrors.NewAggregate(e).Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Errors implements the AggregateError interface
|
// Errors implements the utilerrors.Aggregate interface
|
||||||
func (e errConfigurationInvalid) Errors() []error {
|
func (e errConfigurationInvalid) Errors() []error {
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Is implements the utilerrors.Aggregate interface
|
||||||
|
func (e errConfigurationInvalid) Is(target error) bool {
|
||||||
|
return e.visit(func(err error) bool {
|
||||||
|
return errors.Is(err, target)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e errConfigurationInvalid) visit(f func(err error) bool) bool {
|
||||||
|
for _, err := range e {
|
||||||
|
switch err := err.(type) {
|
||||||
|
case errConfigurationInvalid:
|
||||||
|
if match := err.visit(f); match {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
case utilerrors.Aggregate:
|
||||||
|
for _, nestedErr := range err.Errors() {
|
||||||
|
if match := f(nestedErr); match {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if match := f(err); match {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// IsConfigurationInvalid returns true if the provided error indicates the configuration is invalid.
|
// IsConfigurationInvalid returns true if the provided error indicates the configuration is invalid.
|
||||||
func IsConfigurationInvalid(err error) bool {
|
func IsConfigurationInvalid(err error) bool {
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||||||
package clientcmd
|
package clientcmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
@ -569,3 +571,100 @@ func (c configValidationTest) testAuthInfo(authInfoName string, t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type alwaysMatchingError struct{}
|
||||||
|
|
||||||
|
func (_ alwaysMatchingError) Error() string {
|
||||||
|
return "error"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_ alwaysMatchingError) Is(_ error) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type someError struct{ msg string }
|
||||||
|
|
||||||
|
func (se someError) Error() string {
|
||||||
|
if se.msg != "" {
|
||||||
|
return se.msg
|
||||||
|
}
|
||||||
|
return "err"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrConfigurationInvalidWithErrorsIs(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
matchAgainst error
|
||||||
|
expectMatch bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no match",
|
||||||
|
err: errConfigurationInvalid{errors.New("my-error"), errors.New("my-other-error")},
|
||||||
|
matchAgainst: fmt.Errorf("no entry %s", "here"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match via .Is()",
|
||||||
|
err: errConfigurationInvalid{errors.New("forbidden"), alwaysMatchingError{}},
|
||||||
|
matchAgainst: errors.New("unauthorized"),
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match via equality",
|
||||||
|
err: errConfigurationInvalid{errors.New("err"), someError{}},
|
||||||
|
matchAgainst: someError{},
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match via nested aggregate",
|
||||||
|
err: errConfigurationInvalid{errors.New("closed today"), errConfigurationInvalid{errConfigurationInvalid{someError{}}}},
|
||||||
|
matchAgainst: someError{},
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match via wrapped aggregate",
|
||||||
|
err: fmt.Errorf("wrap: %w", errConfigurationInvalid{errors.New("err"), someError{}}),
|
||||||
|
matchAgainst: someError{},
|
||||||
|
expectMatch: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := errors.Is(tc.err, tc.matchAgainst)
|
||||||
|
if result != tc.expectMatch {
|
||||||
|
t.Errorf("expected match: %t, got match: %t", tc.expectMatch, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type accessTrackingError struct {
|
||||||
|
wasAccessed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (accessTrackingError) Error() string {
|
||||||
|
return "err"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ate *accessTrackingError) Is(_ error) bool {
|
||||||
|
ate.wasAccessed = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ error = &accessTrackingError{}
|
||||||
|
|
||||||
|
func TestErrConfigurationInvalidWithErrorsIsShortCircuitsOnFirstMatch(t *testing.T) {
|
||||||
|
errC := errConfigurationInvalid{&accessTrackingError{}, &accessTrackingError{}}
|
||||||
|
_ = errors.Is(errC, &accessTrackingError{})
|
||||||
|
|
||||||
|
var numAccessed int
|
||||||
|
for _, err := range errC {
|
||||||
|
if ate := err.(*accessTrackingError); ate.wasAccessed {
|
||||||
|
numAccessed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if numAccessed != 1 {
|
||||||
|
t.Errorf("expected exactly one error to get accessed, got %d", numAccessed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user