From 682f2a5a795a4e3a4f76f23ceb03b818af0d69a5 Mon Sep 17 00:00:00 2001 From: Tim Hockin Date: Tue, 3 Nov 2015 16:08:20 -0800 Subject: [PATCH] Stronger typing for validation ErrorList --- examples/examples_test.go | 5 +- pkg/api/errors/errors.go | 16 +++--- pkg/api/rest/update.go | 4 +- pkg/api/validation/schema.go | 15 +++-- pkg/api/validation/validation.go | 4 +- pkg/api/validation/validation_test.go | 43 +++++++------- .../extensions/validation/validation_test.go | 15 +++-- pkg/kubectl/cmd/log.go | 3 +- pkg/kubelet/config/config.go | 5 +- pkg/kubelet/config/http_test.go | 3 +- pkg/util/validation/errors.go | 57 +++++++++++++------ pkg/util/validation/errors_test.go | 30 +++++++++- 12 files changed, 122 insertions(+), 78 deletions(-) diff --git a/examples/examples_test.go b/examples/examples_test.go index f75f9d604e6..c8ae44e8463 100644 --- a/examples/examples_test.go +++ b/examples/examples_test.go @@ -33,12 +33,13 @@ import ( expvalidation "k8s.io/kubernetes/pkg/apis/extensions/validation" "k8s.io/kubernetes/pkg/capabilities" "k8s.io/kubernetes/pkg/runtime" + utilvalidation "k8s.io/kubernetes/pkg/util/validation" "k8s.io/kubernetes/pkg/util/yaml" schedulerapi "k8s.io/kubernetes/plugin/pkg/scheduler/api" schedulerapilatest "k8s.io/kubernetes/plugin/pkg/scheduler/api/latest" ) -func validateObject(obj runtime.Object) (errors []error) { +func validateObject(obj runtime.Object) (errors utilvalidation.ErrorList) { switch t := obj.(type) { case *api.ReplicationController: if t.Namespace == "" { @@ -122,7 +123,7 @@ func validateObject(obj runtime.Object) (errors []error) { } errors = expvalidation.ValidateDaemonSet(t) default: - return []error{fmt.Errorf("no validation defined for %#v", obj)} + return utilvalidation.ErrorList{utilvalidation.NewInternalError("", fmt.Errorf("no validation defined for %#v", obj))} } return errors } diff --git a/pkg/api/errors/errors.go b/pkg/api/errors/errors.go index 0935ed88fb8..ea62742c3ef 100644 --- a/pkg/api/errors/errors.go +++ b/pkg/api/errors/errors.go @@ -24,7 +24,6 @@ import ( "k8s.io/kubernetes/pkg/api/unversioned" "k8s.io/kubernetes/pkg/runtime" - utilerrors "k8s.io/kubernetes/pkg/util/errors" "k8s.io/kubernetes/pkg/util/validation" ) @@ -162,13 +161,12 @@ func NewConflict(kind, name string, err error) error { func NewInvalid(kind, name string, errs validation.ErrorList) error { causes := make([]unversioned.StatusCause, 0, len(errs)) for i := range errs { - if err, ok := errs[i].(*validation.Error); ok { - causes = append(causes, unversioned.StatusCause{ - Type: unversioned.CauseType(err.Type), - Message: err.ErrorBody(), - Field: err.Field, - }) - } + err := errs[i] + causes = append(causes, unversioned.StatusCause{ + Type: unversioned.CauseType(err.Type), + Message: err.ErrorBody(), + Field: err.Field, + }) } return &StatusError{unversioned.Status{ Status: unversioned.StatusFailure, @@ -179,7 +177,7 @@ func NewInvalid(kind, name string, errs validation.ErrorList) error { Name: name, Causes: causes, }, - Message: fmt.Sprintf("%s %q is invalid: %v", kind, name, utilerrors.NewAggregate(errs)), + Message: fmt.Sprintf("%s %q is invalid: %v", kind, name, errs.ToAggregate()), }} } diff --git a/pkg/api/rest/update.go b/pkg/api/rest/update.go index 7a2e9da1b0a..c8e9e30d784 100644 --- a/pkg/api/rest/update.go +++ b/pkg/api/rest/update.go @@ -57,11 +57,11 @@ func validateCommonFields(obj, old runtime.Object) utilvalidation.ErrorList { allErrs := utilvalidation.ErrorList{} objectMeta, err := api.ObjectMetaFor(obj) if err != nil { - return append(allErrs, errors.NewInternalError(err)) + return append(allErrs, utilvalidation.NewInternalError("metadata", err)) } oldObjectMeta, err := api.ObjectMetaFor(old) if err != nil { - return append(allErrs, errors.NewInternalError(err)) + return append(allErrs, utilvalidation.NewInternalError("metadata", err)) } allErrs = append(allErrs, validation.ValidateObjectMetaUpdate(objectMeta, oldObjectMeta)...) diff --git a/pkg/api/validation/schema.go b/pkg/api/validation/schema.go index 57755cd64d6..24c1e26199f 100644 --- a/pkg/api/validation/schema.go +++ b/pkg/api/validation/schema.go @@ -27,7 +27,6 @@ import ( "github.com/golang/glog" apiutil "k8s.io/kubernetes/pkg/api/util" utilerrors "k8s.io/kubernetes/pkg/util/errors" - "k8s.io/kubernetes/pkg/util/validation" "k8s.io/kubernetes/pkg/util/yaml" ) @@ -67,11 +66,11 @@ func NewSwaggerSchemaFromBytes(data []byte) (Schema, error) { return schema, nil } -// validateList unpack a list and validate every item in the list. +// validateList unpacks a list and validate every item in the list. // It return nil if every item is ok. // Otherwise it return an error list contain errors of every item. -func (s *SwaggerSchema) validateList(obj map[string]interface{}) validation.ErrorList { - allErrs := validation.ErrorList{} +func (s *SwaggerSchema) validateList(obj map[string]interface{}) []error { + allErrs := []error{} items, exists := obj["items"] if !exists { return append(allErrs, fmt.Errorf("no items field in %#v", obj)) @@ -160,8 +159,8 @@ func (s *SwaggerSchema) ValidateBytes(data []byte) error { return utilerrors.NewAggregate(allErrs) } -func (s *SwaggerSchema) ValidateObject(obj interface{}, fieldName, typeName string) validation.ErrorList { - allErrs := validation.ErrorList{} +func (s *SwaggerSchema) ValidateObject(obj interface{}, fieldName, typeName string) []error { + allErrs := []error{} models := s.api.Models model, ok := models.At(typeName) if !ok { @@ -215,7 +214,7 @@ func (s *SwaggerSchema) ValidateObject(obj interface{}, fieldName, typeName stri // This matches type name in the swagger spec, such as "v1.Binding". var versionRegexp = regexp.MustCompile(`^v.+\..*`) -func (s *SwaggerSchema) validateField(value interface{}, fieldName, fieldType string, fieldDetails *swagger.ModelProperty) validation.ErrorList { +func (s *SwaggerSchema) validateField(value interface{}, fieldName, fieldType string, fieldDetails *swagger.ModelProperty) []error { // TODO: caesarxuchao: because we have multiple group/versions and objects // may reference objects in other group, the commented out way of checking // if a filedType is a type defined by us is outdated. We use a hacky way @@ -229,7 +228,7 @@ func (s *SwaggerSchema) validateField(value interface{}, fieldName, fieldType st // if strings.HasPrefix(fieldType, apiVersion) { return s.ValidateObject(value, fieldName, fieldType) } - allErrs := validation.ErrorList{} + allErrs := []error{} switch fieldType { case "string": // Be loose about what we accept for 'string' since we use IntOrString in a couple of places diff --git a/pkg/api/validation/validation.go b/pkg/api/validation/validation.go index 3224e4ef0fb..47469208b9f 100644 --- a/pkg/api/validation/validation.go +++ b/pkg/api/validation/validation.go @@ -1480,7 +1480,7 @@ func ValidateNodeUpdate(node, oldNode *api.Node) validation.ErrorList { addresses := make(map[api.NodeAddress]bool) for _, address := range node.Status.Addresses { if _, ok := addresses[address]; ok { - allErrs = append(allErrs, fmt.Errorf("duplicate node addresses found")) + allErrs = append(allErrs, validation.NewFieldDuplicate("addresses", address)) } addresses[address] = true } @@ -1500,7 +1500,7 @@ func ValidateNodeUpdate(node, oldNode *api.Node) validation.ErrorList { // TODO: Add a 'real' error type for this error and provide print actual diffs. if !api.Semantic.DeepEqual(oldNode, node) { glog.V(4).Infof("Update failed validation %#v vs %#v", oldNode, node) - allErrs = append(allErrs, fmt.Errorf("update contains more than labels or capacity changes")) + allErrs = append(allErrs, validation.NewFieldForbidden("", "update contains more than labels or capacity changes")) } return allErrs diff --git a/pkg/api/validation/validation_test.go b/pkg/api/validation/validation_test.go index 21c20ec0bbe..71b39eae611 100644 --- a/pkg/api/validation/validation_test.go +++ b/pkg/api/validation/validation_test.go @@ -28,7 +28,6 @@ import ( "k8s.io/kubernetes/pkg/api/testapi" "k8s.io/kubernetes/pkg/api/unversioned" "k8s.io/kubernetes/pkg/capabilities" - utilerrors "k8s.io/kubernetes/pkg/util/errors" "k8s.io/kubernetes/pkg/util/intstr" "k8s.io/kubernetes/pkg/util/sets" "k8s.io/kubernetes/pkg/util/validation" @@ -36,7 +35,7 @@ import ( func expectPrefix(t *testing.T, prefix string, errs validation.ErrorList) { for i := range errs { - if f, p := errs[i].(*validation.Error).Field, prefix; !strings.HasPrefix(f, p) { + if f, p := errs[i].Field, prefix; !strings.HasPrefix(f, p) { t.Errorf("expected prefix '%s' for field '%s' (%v)", p, f, errs[i]) } } @@ -150,7 +149,7 @@ func TestValidateLabels(t *testing.T) { if len(errs) != 1 { t.Errorf("case[%d] expected failure", i) } else { - detail := errs[0].(*validation.Error).Detail + detail := errs[0].Detail if detail != qualifiedNameErrorMsg { t.Errorf("error detail %s should be equal %s", detail, qualifiedNameErrorMsg) } @@ -168,7 +167,7 @@ func TestValidateLabels(t *testing.T) { if len(errs) != 1 { t.Errorf("case[%d] expected failure", i) } else { - detail := errs[0].(*validation.Error).Detail + detail := errs[0].Detail if detail != labelValueErrorMsg { t.Errorf("error detail %s should be equal %s", detail, labelValueErrorMsg) } @@ -215,7 +214,7 @@ func TestValidateAnnotations(t *testing.T) { if len(errs) != 1 { t.Errorf("case[%d] expected failure", i) } - detail := errs[0].(*validation.Error).Detail + detail := errs[0].Detail if detail != qualifiedNameErrorMsg { t.Errorf("error detail %s should be equal %s", detail, qualifiedNameErrorMsg) } @@ -568,13 +567,13 @@ func TestValidateVolumes(t *testing.T) { continue } for i := range errs { - if errs[i].(*validation.Error).Type != v.T { + if errs[i].Type != v.T { t.Errorf("%s: expected errors to have type %s: %v", k, v.T, errs[i]) } - if errs[i].(*validation.Error).Field != v.F { + if errs[i].Field != v.F { t.Errorf("%s: expected errors to have field %s: %v", k, v.F, errs[i]) } - detail := errs[i].(*validation.Error).Detail + detail := errs[i].Detail if detail != v.D { t.Errorf("%s: expected error detail \"%s\", got \"%s\"", k, v.D, detail) } @@ -627,13 +626,13 @@ func TestValidatePorts(t *testing.T) { t.Errorf("expected failure for %s", k) } for i := range errs { - if errs[i].(*validation.Error).Type != v.T { + if errs[i].Type != v.T { t.Errorf("%s: expected errors to have type %s: %v", k, v.T, errs[i]) } - if errs[i].(*validation.Error).Field != v.F { + if errs[i].Field != v.F { t.Errorf("%s: expected errors to have field %s: %v", k, v.F, errs[i]) } - detail := errs[i].(*validation.Error).Detail + detail := errs[i].Detail if detail != v.D { t.Errorf("%s: expected error detail either empty or %s, got %s", k, v.D, detail) } @@ -772,7 +771,7 @@ func TestValidateEnv(t *testing.T) { t.Errorf("expected failure for %s", tc.name) } else { for i := range errs { - str := errs[i].(*validation.Error).Error() + str := errs[i].Error() if str != "" && str != tc.expectedError { t.Errorf("%s: expected error detail either empty or %s, got %s", tc.name, tc.expectedError, str) } @@ -2108,7 +2107,7 @@ func TestValidateService(t *testing.T) { tc.tweakSvc(&svc) errs := ValidateService(&svc) if len(errs) != tc.numErrs { - t.Errorf("Unexpected error list for case %q: %v", tc.name, utilerrors.NewAggregate(errs)) + t.Errorf("Unexpected error list for case %q: %v", tc.name, errs.ToAggregate()) } } } @@ -2560,7 +2559,7 @@ func TestValidateReplicationController(t *testing.T) { t.Errorf("expected failure for %s", k) } for i := range errs { - field := errs[i].(*validation.Error).Field + field := errs[i].Field if !strings.HasPrefix(field, "spec.template.") && field != "metadata.name" && field != "metadata.namespace" && @@ -2676,7 +2675,7 @@ func TestValidateNode(t *testing.T) { t.Errorf("expected failure for %s", k) } for i := range errs { - field := errs[i].(*validation.Error).Field + field := errs[i].Field expectedFields := map[string]bool{ "metadata.name": true, "metadata.labels": true, @@ -2974,7 +2973,7 @@ func TestValidateServiceUpdate(t *testing.T) { tc.tweakSvc(&oldSvc, &newSvc) errs := ValidateServiceUpdate(&newSvc, &oldSvc) if len(errs) != tc.numErrs { - t.Errorf("Unexpected error list for case %q: %v", tc.name, utilerrors.NewAggregate(errs)) + t.Errorf("Unexpected error list for case %q: %v", tc.name, errs.ToAggregate()) } } } @@ -3008,7 +3007,7 @@ func TestValidateResourceNames(t *testing.T) { } else if len(err) == 0 && !item.success { t.Errorf("expected failure for input %q", item.input) for i := range err { - detail := err[i].(*validation.Error).Detail + detail := err[i].Detail if detail != "" && detail != qualifiedNameErrorMsg { t.Errorf("%d: expected error detail either empty or %s, got %s", k, qualifiedNameErrorMsg, detail) } @@ -3224,7 +3223,7 @@ func TestValidateLimitRange(t *testing.T) { t.Errorf("expected failure for %s", k) } for i := range errs { - detail := errs[i].(*validation.Error).Detail + detail := errs[i].Detail if detail != v.D { t.Errorf("%s: expected error detail either empty or %s, got %s", k, v.D, detail) } @@ -3329,8 +3328,8 @@ func TestValidateResourceQuota(t *testing.T) { t.Errorf("expected failure for %s", k) } for i := range errs { - field := errs[i].(*validation.Error).Field - detail := errs[i].(*validation.Error).Detail + field := errs[i].Field + detail := errs[i].Detail if field != "metadata.name" && field != "metadata.namespace" && !api.IsStandardResourceName(field) { t.Errorf("%s: missing prefix for: %v", k, field) } @@ -3937,7 +3936,7 @@ func TestValidateEndpoints(t *testing.T) { } for k, v := range errorCases { - if errs := ValidateEndpoints(&v.endpoints); len(errs) == 0 || errs[0].(*validation.Error).Type != v.errorType || !strings.Contains(errs[0].(*validation.Error).Detail, v.errorDetail) { + if errs := ValidateEndpoints(&v.endpoints); len(errs) == 0 || errs[0].Type != v.errorType || !strings.Contains(errs[0].Detail, v.errorDetail) { t.Errorf("Expected error type %s with detail %s for %s, got %v", v.errorType, v.errorDetail, k, errs) } } @@ -4017,7 +4016,7 @@ func TestValidateSecurityContext(t *testing.T) { }, } for k, v := range errorCases { - if errs := ValidateSecurityContext(v.sc); len(errs) == 0 || errs[0].(*validation.Error).Type != v.errorType || errs[0].(*validation.Error).Detail != v.errorDetail { + if errs := ValidateSecurityContext(v.sc); len(errs) == 0 || errs[0].Type != v.errorType || errs[0].Detail != v.errorDetail { t.Errorf("Expected error type %s with detail %s for %s, got %v", v.errorType, v.errorDetail, k, errs) } } diff --git a/pkg/apis/extensions/validation/validation_test.go b/pkg/apis/extensions/validation/validation_test.go index 94d74de2fe7..497e88e88fc 100644 --- a/pkg/apis/extensions/validation/validation_test.go +++ b/pkg/apis/extensions/validation/validation_test.go @@ -24,7 +24,6 @@ import ( "k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/apis/extensions" "k8s.io/kubernetes/pkg/util/intstr" - "k8s.io/kubernetes/pkg/util/validation" ) func TestValidateHorizontalPodAutoscaler(t *testing.T) { @@ -675,7 +674,7 @@ func TestValidateDaemonSet(t *testing.T) { t.Errorf("expected failure for %s", k) } for i := range errs { - field := errs[i].(*validation.Error).Field + field := errs[i].Field if !strings.HasPrefix(field, "spec.template.") && field != "metadata.name" && field != "metadata.namespace" && @@ -918,9 +917,9 @@ func TestValidateJob(t *testing.T) { t.Errorf("expected failure for %s", k) } else { s := strings.Split(k, ":") - err := errs[0].(*validation.Error) + err := errs[0] if err.Field != s[0] || !strings.Contains(err.Error(), s[1]) { - t.Errorf("unexpected error: %v, expected: %s", errs[0], k) + t.Errorf("unexpected error: %v, expected: %s", err, k) } } } @@ -1019,9 +1018,9 @@ func TestValidateIngress(t *testing.T) { t.Errorf("expected failure for %s", k) } else { s := strings.Split(k, ":") - err := errs[0].(*validation.Error) + err := errs[0] if err.Field != s[0] || !strings.Contains(err.Error(), s[1]) { - t.Errorf("unexpected error: %v, expected: %s", errs[0], k) + t.Errorf("unexpected error: %v, expected: %s", err, k) } } } @@ -1111,9 +1110,9 @@ func TestValidateIngressStatusUpdate(t *testing.T) { t.Errorf("expected failure for %s", k) } else { s := strings.Split(k, ":") - err := errs[0].(*validation.Error) + err := errs[0] if err.Field != s[0] || !strings.Contains(err.Error(), s[1]) { - t.Errorf("unexpected error: %v, expected: %s", errs[0], k) + t.Errorf("unexpected error: %v, expected: %s", err, k) } } } diff --git a/pkg/kubectl/cmd/log.go b/pkg/kubectl/cmd/log.go index 722b1f10166..430511d6b30 100644 --- a/pkg/kubectl/cmd/log.go +++ b/pkg/kubectl/cmd/log.go @@ -32,7 +32,6 @@ import ( cmdutil "k8s.io/kubernetes/pkg/kubectl/cmd/util" "k8s.io/kubernetes/pkg/kubectl/resource" "k8s.io/kubernetes/pkg/runtime" - kerrors "k8s.io/kubernetes/pkg/util/errors" ) const ( @@ -169,7 +168,7 @@ func (o LogsOptions) Validate() error { return errors.New("unexpected log options object") } if errs := validation.ValidatePodLogOptions(logOptions); len(errs) > 0 { - return kerrors.NewAggregate(errs) + return errs.ToAggregate() } return nil diff --git a/pkg/kubelet/config/config.go b/pkg/kubelet/config/config.go index ff0bd1a0459..737c6aa7425 100644 --- a/pkg/kubelet/config/config.go +++ b/pkg/kubelet/config/config.go @@ -29,7 +29,6 @@ import ( kubetypes "k8s.io/kubernetes/pkg/kubelet/types" kubeletutil "k8s.io/kubernetes/pkg/kubelet/util" "k8s.io/kubernetes/pkg/util/config" - utilerrors "k8s.io/kubernetes/pkg/util/errors" "k8s.io/kubernetes/pkg/util/sets" utilvalidation "k8s.io/kubernetes/pkg/util/validation" ) @@ -310,7 +309,7 @@ func (s *podStorage) seenSources(sources ...string) bool { func filterInvalidPods(pods []*api.Pod, source string, recorder record.EventRecorder) (filtered []*api.Pod) { names := sets.String{} for i, pod := range pods { - var errlist []error + var errlist utilvalidation.ErrorList if errs := validation.ValidatePod(pod); len(errs) != 0 { errlist = append(errlist, errs...) // If validation fails, don't trust it any further - @@ -325,7 +324,7 @@ func filterInvalidPods(pods []*api.Pod, source string, recorder record.EventReco } if len(errlist) > 0 { name := bestPodIdentString(pod) - err := utilerrors.NewAggregate(errlist) + err := errlist.ToAggregate() glog.Warningf("Pod[%d] (%s) from %s failed validation, ignoring: %v", i+1, name, source, err) recorder.Eventf(pod, kubecontainer.FailedValidation, "Error validating pod %s from %s, ignoring: %v", name, source, err) continue diff --git a/pkg/kubelet/config/http_test.go b/pkg/kubelet/config/http_test.go index 84115a4cc7f..f257bc3fc68 100644 --- a/pkg/kubelet/config/http_test.go +++ b/pkg/kubelet/config/http_test.go @@ -30,7 +30,6 @@ import ( kubetypes "k8s.io/kubernetes/pkg/kubelet/types" "k8s.io/kubernetes/pkg/runtime" "k8s.io/kubernetes/pkg/util" - utilerrors "k8s.io/kubernetes/pkg/util/errors" ) func TestURLErrorNotExistNoUpdate(t *testing.T) { @@ -286,7 +285,7 @@ func TestExtractPodsFromHTTP(t *testing.T) { } for _, pod := range update.Pods { if errs := validation.ValidatePod(pod); len(errs) != 0 { - t.Errorf("%s: Expected no validation errors on %#v, Got %v", testCase.desc, pod, utilerrors.NewAggregate(errs)) + t.Errorf("%s: Expected no validation errors on %#v, Got %v", testCase.desc, pod, errs.ToAggregate()) } } } diff --git a/pkg/util/validation/errors.go b/pkg/util/validation/errors.go index 17d5e4630f4..56d39cf89b6 100644 --- a/pkg/util/validation/errors.go +++ b/pkg/util/validation/errors.go @@ -46,7 +46,7 @@ func (v *Error) Error() string { func (v *Error) ErrorBody() string { var s string switch v.Type { - case ErrorTypeRequired, ErrorTypeTooLong: + case ErrorTypeRequired, ErrorTypeTooLong, ErrorTypeInternal: s = spew.Sprintf("%s", v.Type) default: s = spew.Sprintf("%s '%+v'", v.Type, v.BadValue) @@ -89,6 +89,9 @@ const ( // This is similar to ErrorTypeInvalid, but the error will not include the // too-long value. See NewFieldTooLong. ErrorTypeTooLong ErrorType = "FieldValueTooLong" + // ErrorTypeInternal is used to report other errors that are not related + // to user input. + ErrorTypeInternal ErrorType = "InternalError" ) // String converts a ErrorType into its corresponding canonical error message. @@ -108,6 +111,8 @@ func (t ErrorType) String() string { return "forbidden" case ErrorTypeTooLong: return "too long" + case ErrorTypeInternal: + return "internal error" default: panic(fmt.Sprintf("unrecognized validation error: %q", t)) return "" @@ -166,24 +171,27 @@ func NewFieldTooLong(field string, value interface{}, maxLength int) *Error { return &Error{ErrorTypeTooLong, field, value, fmt.Sprintf("must have at most %d characters", maxLength)} } +// NewInternalError returns a *Error indicating "internal error". This is used +// to signal that an error was found that was not directly related to user +// input. The err argument must be non-nil. +func NewInternalError(field string, err error) *Error { + return &Error{ErrorTypeInternal, field, nil, err.Error()} +} + // ErrorList holds a set of errors. -type ErrorList []error +type ErrorList []*Error // Prefix adds a prefix to the Field of every Error in the list. // Returns the list for convenience. func (list ErrorList) Prefix(prefix string) ErrorList { for i := range list { - if err, ok := list[i].(*Error); ok { - if strings.HasPrefix(err.Field, "[") { - err.Field = prefix + err.Field - } else if len(err.Field) != 0 { - err.Field = prefix + "." + err.Field - } else { - err.Field = prefix - } - list[i] = err + err := list[i] + if strings.HasPrefix(err.Field, "[") { + err.Field = prefix + err.Field + } else if len(err.Field) != 0 { + err.Field = prefix + "." + err.Field } else { - panic(fmt.Sprintf("Programmer error: ErrorList holds non-Error: %#v", list[i])) + err.Field = prefix } } return list @@ -206,13 +214,30 @@ func NewErrorTypeMatcher(t ErrorType) utilerrors.Matcher { } } +// ToAggregate converts the ErrorList into an errors.Aggregate. +func (list ErrorList) ToAggregate() utilerrors.Aggregate { + errs := make([]error, len(list)) + for i := range list { + errs[i] = list[i] + } + return utilerrors.NewAggregate(errs) +} + +func fromAggregate(agg utilerrors.Aggregate) ErrorList { + errs := agg.Errors() + list := make(ErrorList, len(errs)) + for i := range errs { + list[i] = errs[i].(*Error) + } + return list +} + // Filter removes items from the ErrorList that match the provided fns. func (list ErrorList) Filter(fns ...utilerrors.Matcher) ErrorList { - err := utilerrors.FilterOut(utilerrors.NewAggregate(list), fns...) + err := utilerrors.FilterOut(list.ToAggregate(), fns...) if err == nil { return nil } - // FilterOut that takes an Aggregate returns an Aggregate - agg := err.(utilerrors.Aggregate) - return ErrorList(agg.Errors()) + // FilterOut takes an Aggregate and returns an Aggregate + return fromAggregate(err.(utilerrors.Aggregate)) } diff --git a/pkg/util/validation/errors_test.go b/pkg/util/validation/errors_test.go index 777c9abcb3e..d4fb1884295 100644 --- a/pkg/util/validation/errors_test.go +++ b/pkg/util/validation/errors_test.go @@ -17,6 +17,7 @@ limitations under the License. package validation import ( + "fmt" "strings" "testing" ) @@ -46,6 +47,10 @@ func TestMakeFuncs(t *testing.T) { func() *Error { return NewFieldRequired("f") }, ErrorTypeRequired, }, + { + func() *Error { return NewInternalError("f", fmt.Errorf("e")) }, + ErrorTypeInternal, + }, } for _, testCase := range testCases { @@ -93,6 +98,27 @@ func TestErrorUsefulMessage(t *testing.T) { } } +func TestToAggregate(t *testing.T) { + testCases := []ErrorList{ + nil, + {}, + {NewFieldInvalid("f", "v", "d")}, + {NewFieldInvalid("f", "v", "d"), NewInternalError("", fmt.Errorf("e"))}, + } + for i, tc := range testCases { + agg := tc.ToAggregate() + if len(tc) == 0 { + if agg != nil { + t.Errorf("[%d] Expected nil, got %#v", i, agg) + } + } else if agg == nil { + t.Errorf("[%d] Expected non-nil", i) + } else if len(tc) != len(agg.Errors()) { + t.Errorf("[%d] Expected %d, got %d", i, len(tc), len(agg.Errors())) + } + } +} + func TestErrListFilter(t *testing.T) { list := ErrorList{ NewFieldInvalid("test.field", "", ""), @@ -131,7 +157,7 @@ func TestErrListPrefix(t *testing.T) { if prefix == nil || len(prefix) != len(errList) { t.Errorf("Prefix should return self") } - if e, a := testCase.Expected, errList[0].(*Error).Field; e != a { + if e, a := testCase.Expected, errList[0].Field; e != a { t.Errorf("expected %s, got %s", e, a) } } @@ -161,7 +187,7 @@ func TestErrListPrefixIndex(t *testing.T) { if prefix == nil || len(prefix) != len(errList) { t.Errorf("PrefixIndex should return self") } - if e, a := testCase.Expected, errList[0].(*Error).Field; e != a { + if e, a := testCase.Expected, errList[0].Field; e != a { t.Errorf("expected %s, got %s", e, a) } }