diff --git a/pkg/util/ipset/ipset.go b/pkg/util/ipset/ipset.go index 04c920a83f9..2b145f08f8b 100644 --- a/pkg/util/ipset/ipset.go +++ b/pkg/util/ipset/ipset.go @@ -92,35 +92,33 @@ type IPSet struct { } // Validate checks if a given ipset is valid or not. -func (set *IPSet) Validate() bool { +func (set *IPSet) Validate() error { // Check if protocol is valid for `HashIPPort`, `HashIPPortIP` and `HashIPPortNet` type set. if set.SetType == HashIPPort || set.SetType == HashIPPortIP || set.SetType == HashIPPortNet { - if valid := validateHashFamily(set.HashFamily); !valid { - return false + if err := validateHashFamily(set.HashFamily); err != nil { + return err } } // check set type - if valid := validateIPSetType(set.SetType); !valid { - return false + if err := validateIPSetType(set.SetType); err != nil { + return err } // check port range for bitmap type set if set.SetType == BitmapPort { - if valid := validatePortRange(set.PortRange); !valid { - return false + if err := validatePortRange(set.PortRange); err != nil { + return err } } // check hash size value of ipset if set.HashSize <= 0 { - klog.Errorf("Invalid hashsize value %d, should be >0", set.HashSize) - return false + return fmt.Errorf("invalid HashSize: %d", set.HashSize) } // check max elem value of ipset if set.MaxElem <= 0 { - klog.Errorf("Invalid maxelem value %d, should be >0", set.MaxElem) - return false + return fmt.Errorf("invalid MaxElem %d", set.MaxElem) } - return true + return nil } //setIPSetDefaults sets some IPSet fields if not present to their default values. @@ -275,9 +273,8 @@ func (runner *runner) CreateSet(set *IPSet, ignoreExistErr bool) error { set.setIPSetDefaults() // Validate ipset before creating - valid := set.Validate() - if !valid { - return fmt.Errorf("error creating ipset since it's invalid") + if err := set.Validate(); err != nil { + return err } return runner.createSet(set, ignoreExistErr) } @@ -421,44 +418,39 @@ func getIPSetVersionString(exec utilexec.Interface) (string, error) { // checks if port range is valid. The begin port number is not necessarily less than // end port number - ipset util can accept it. It means both 1-100 and 100-1 are valid. -func validatePortRange(portRange string) bool { +func validatePortRange(portRange string) error { strs := strings.Split(portRange, "-") if len(strs) != 2 { - klog.Errorf("port range should be in the format of `a-b`") - return false + return fmt.Errorf("invalid PortRange: %q", portRange) } for i := range strs { num, err := strconv.Atoi(strs[i]) if err != nil { - klog.Errorf("Failed to parse %s, error: %v", strs[i], err) - return false + return fmt.Errorf("invalid PortRange: %q", portRange) } if num < 0 { - klog.Errorf("port number %d should be >=0", num) - return false + return fmt.Errorf("invalid PortRange: %q", portRange) } } - return true + return nil } // checks if the given ipset type is valid. -func validateIPSetType(set Type) bool { +func validateIPSetType(set Type) error { for _, valid := range ValidIPSetTypes { if set == valid { - return true + return nil } } - klog.Errorf("Currently supported ipset types are: %v, %s is not supported", ValidIPSetTypes, set) - return false + return fmt.Errorf("unsupported SetType: %q", set) } // checks if given hash family is supported in ipset -func validateHashFamily(family string) bool { +func validateHashFamily(family string) error { if family == ProtocolFamilyIPV4 || family == ProtocolFamilyIPV6 { - return true + return nil } - klog.Errorf("Currently supported ip set hash families are: [%s, %s], %s is not supported", ProtocolFamilyIPV4, ProtocolFamilyIPV6, family) - return false + return fmt.Errorf("unsupported HashFamily %q", family) } // IsNotFoundError returns true if the error indicates "not found". It parses diff --git a/pkg/util/ipset/ipset_test.go b/pkg/util/ipset/ipset_test.go index 030e83fa579..b0f2082b8a1 100644 --- a/pkg/util/ipset/ipset_test.go +++ b/pkg/util/ipset/ipset_test.go @@ -662,38 +662,41 @@ baz` func Test_validIPSetType(t *testing.T) { testCases := []struct { - setType Type - valid bool + setType Type + expectErr bool }{ { // case[0] - setType: Type("foo"), - valid: false, + setType: Type("foo"), + expectErr: true, }, { // case[1] - setType: HashIPPortNet, - valid: true, + setType: HashIPPortNet, + expectErr: false, }, { // case[2] - setType: HashIPPort, - valid: true, + setType: HashIPPort, + expectErr: false, }, { // case[3] - setType: HashIPPortIP, - valid: true, + setType: HashIPPortIP, + expectErr: false, }, { // case[4] - setType: BitmapPort, - valid: true, + setType: BitmapPort, + expectErr: false, }, { // case[5] - setType: Type(""), - valid: false, + setType: Type(""), + expectErr: true, }, } for i := range testCases { - valid := validateIPSetType(testCases[i].setType) - if valid != testCases[i].valid { - t.Errorf("case [%d]: unexpected mismatch, expect valid[%v], got valid[%v]", i, testCases[i].valid, valid) + err := validateIPSetType(testCases[i].setType) + if err != nil { + if !testCases[i].expectErr { + t.Errorf("case [%d]: unexpected mismatch, expect error[%v], got error[%v]", i, testCases[i].expectErr, err) + } + continue } } } @@ -701,134 +704,140 @@ func Test_validIPSetType(t *testing.T) { func Test_validatePortRange(t *testing.T) { testCases := []struct { portRange string - valid bool + expectErr bool desc string }{ { // case[0] portRange: "a-b", - valid: false, + expectErr: true, desc: "invalid port number", }, { // case[1] portRange: "1-2", - valid: true, + expectErr: false, desc: "valid", }, { // case[2] portRange: "90-1", - valid: true, + expectErr: false, desc: "ipset util can accept the input of begin port number can be less than end port number", }, { // case[3] portRange: DefaultPortRange, - valid: true, + expectErr: false, desc: "default port range is valid, of course", }, { // case[4] portRange: "12", - valid: false, + expectErr: true, desc: "a single number is invalid", }, { // case[5] portRange: "1-", - valid: false, + expectErr: true, desc: "should specify end port", }, { // case[6] portRange: "-100", - valid: false, + expectErr: true, desc: "should specify begin port", }, { // case[7] portRange: "1:100", - valid: false, + expectErr: true, desc: "delimiter should be -", }, { // case[8] portRange: "1~100", - valid: false, + expectErr: true, desc: "delimiter should be -", }, { // case[9] portRange: "1,100", - valid: false, + expectErr: true, desc: "delimiter should be -", }, { // case[10] portRange: "100-100", - valid: true, + expectErr: false, desc: "begin port number can be equal to end port number", }, { // case[11] portRange: "", - valid: false, + expectErr: true, desc: "empty string is invalid", }, { // case[12] portRange: "-1-12", - valid: false, + expectErr: true, desc: "port number can not be negative value", }, { // case[13] portRange: "-1--8", - valid: false, + expectErr: true, desc: "port number can not be negative value", }, } for i := range testCases { - valid := validatePortRange(testCases[i].portRange) - if valid != testCases[i].valid { - t.Errorf("case [%d]: unexpected mismatch, expect valid[%v], got valid[%v], desc: %s", i, testCases[i].valid, valid, testCases[i].desc) + err := validatePortRange(testCases[i].portRange) + if err != nil { + if !testCases[i].expectErr { + t.Errorf("case [%d]: unexpected mismatch, expect error[%v], got error[%v], desc: %s", i, testCases[i].expectErr, err, testCases[i].desc) + } + continue } } } func Test_validateFamily(t *testing.T) { testCases := []struct { - family string - valid bool + family string + expectErr bool }{ { // case[0] - family: "foo", - valid: false, + family: "foo", + expectErr: true, }, { // case[1] - family: ProtocolFamilyIPV4, - valid: true, + family: ProtocolFamilyIPV4, + expectErr: false, }, { // case[2] - family: ProtocolFamilyIPV6, - valid: true, + family: ProtocolFamilyIPV6, + expectErr: false, }, { // case[3] - family: "ipv4", - valid: false, + family: "ipv4", + expectErr: true, }, { // case[4] - family: "ipv6", - valid: false, + family: "ipv6", + expectErr: true, }, { // case[5] - family: "tcp", - valid: false, + family: "tcp", + expectErr: true, }, { // case[6] - family: "udp", - valid: false, + family: "udp", + expectErr: true, }, { // case[7] - family: "", - valid: false, + family: "", + expectErr: true, }, { // case[8] - family: "sctp", - valid: false, + family: "sctp", + expectErr: true, }, } for i := range testCases { - valid := validateHashFamily(testCases[i].family) - if valid != testCases[i].valid { - t.Errorf("case [%d]: unexpected mismatch, expect valid[%v], got valid[%v]", i, testCases[i].valid, valid) + err := validateHashFamily(testCases[i].family) + if err != nil { + if !testCases[i].expectErr { + t.Errorf("case [%d]: unexpected err: %v, desc: %s", i, err, testCases[i].family) + } + continue } } } @@ -888,9 +897,9 @@ func Test_validateProtocol(t *testing.T) { func TestValidateIPSet(t *testing.T) { testCases := []struct { - ipset *IPSet - valid bool - desc string + ipset *IPSet + expectErr bool + desc string }{ { // case[0] ipset: &IPSet{ @@ -900,7 +909,8 @@ func TestValidateIPSet(t *testing.T) { HashSize: 1024, MaxElem: 1024, }, - valid: true, + expectErr: false, + desc: "No Port range", }, { // case[1] ipset: &IPSet{ @@ -911,7 +921,8 @@ func TestValidateIPSet(t *testing.T) { MaxElem: 2048, PortRange: DefaultPortRange, }, - valid: true, + expectErr: false, + desc: "control case", }, { // case[2] ipset: &IPSet{ @@ -921,8 +932,8 @@ func TestValidateIPSet(t *testing.T) { HashSize: 65535, MaxElem: 2048, }, - valid: false, - desc: "should specify right port range for bitmap type set", + expectErr: true, + desc: "should specify right port range for bitmap type set", }, { // case[3] ipset: &IPSet{ @@ -932,8 +943,8 @@ func TestValidateIPSet(t *testing.T) { HashSize: 0, MaxElem: 2048, }, - valid: false, - desc: "wrong hash size number", + expectErr: true, + desc: "wrong hash size number", }, { // case[4] ipset: &IPSet{ @@ -943,8 +954,8 @@ func TestValidateIPSet(t *testing.T) { HashSize: 1024, MaxElem: -1, }, - valid: false, - desc: "wrong hash max elem number", + expectErr: true, + desc: "wrong hash max elem number", }, { // case[5] ipset: &IPSet{ @@ -954,8 +965,8 @@ func TestValidateIPSet(t *testing.T) { HashSize: 1024, MaxElem: 1024, }, - valid: false, - desc: "wrong protocol", + expectErr: true, + desc: "wrong protocol", }, { // case[6] ipset: &IPSet{ @@ -965,14 +976,17 @@ func TestValidateIPSet(t *testing.T) { HashSize: 1024, MaxElem: 1024, }, - valid: false, - desc: "wrong set type", + expectErr: true, + desc: "wrong set type", }, } for i := range testCases { - valid := testCases[i].ipset.Validate() - if valid != testCases[i].valid { - t.Errorf("case [%d]: unexpected mismatch, expect valid[%v], got valid[%v], desc: %s", i, testCases[i].valid, valid, testCases[i].desc) + err := testCases[i].ipset.Validate() + if err != nil { + if !testCases[i].expectErr { + t.Errorf("case [%d]: unexpected mismatch, expect error[%v], got error[%v], desc: %s", i, testCases[i].expectErr, err, testCases[i].desc) + } + continue } } }