diff --git a/pkg/api/service/annotations.go b/pkg/api/service/annotations.go new file mode 100644 index 00000000000..9d57fa4c208 --- /dev/null +++ b/pkg/api/service/annotations.go @@ -0,0 +1,28 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package service + +const ( + // AnnotationLoadBalancerSourceRangesKey is the key of the annotation on a service to set allowed ingress ranges on their LoadBalancers + // + // It should be a comma-separated list of CIDRs, e.g. `0.0.0.0/0` to + // allow full access (the default) or `18.0.0.0/8,56.0.0.0/8` to allow + // access only from the CIDRs currently allocated to MIT & the USPS. + // + // Not all cloud providers support this annotation, though AWS & GCE do. + AnnotationLoadBalancerSourceRangesKey = "service.beta.kubernetes.io/load-balancer-source-ranges" +) diff --git a/pkg/api/service/util.go b/pkg/api/service/util.go new file mode 100644 index 00000000000..a77e5b9c70b --- /dev/null +++ b/pkg/api/service/util.go @@ -0,0 +1,54 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package service + +import ( + "fmt" + "strings" + + netsets "k8s.io/kubernetes/pkg/util/net/sets" +) + +const ( + defaultLoadBalancerSourceRanges = "0.0.0.0/0" +) + +// IsAllowAll checks whether the netsets.IPNet allows traffic from 0.0.0.0/0 +func IsAllowAll(ipnets netsets.IPNet) bool { + for _, s := range ipnets.StringSlice() { + if s == "0.0.0.0/0" { + return true + } + } + return false +} + +// GetLoadBalancerSourceRanges verifies and parses the AnnotationLoadBalancerSourceRangesKey annotation from a service, +// extracting the source ranges to allow, and if not present returns a default (allow-all) value. +func GetLoadBalancerSourceRanges(annotations map[string]string) (netsets.IPNet, error) { + val := annotations[AnnotationLoadBalancerSourceRangesKey] + val = strings.TrimSpace(val) + if val == "" { + val = defaultLoadBalancerSourceRanges + } + specs := strings.Split(val, ",") + ipnets, err := netsets.ParseIPNets(specs...) + if err != nil { + return nil, fmt.Errorf("Service annotation %s:%s is not valid. Expecting a comma-separated list of source IP ranges. For example, 10.0.0.0/24,192.168.2.0/24", AnnotationLoadBalancerSourceRangesKey, val) + } + return ipnets, nil +} diff --git a/pkg/api/service/util_test.go b/pkg/api/service/util_test.go new file mode 100644 index 00000000000..c77d4f25906 --- /dev/null +++ b/pkg/api/service/util_test.go @@ -0,0 +1,92 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package service + +import ( + "testing" + + netsets "k8s.io/kubernetes/pkg/util/net/sets" +) + +func TestGetLoadBalancerSourceRanges(t *testing.T) { + checkError := func(v string) { + annotations := make(map[string]string) + annotations[AnnotationLoadBalancerSourceRangesKey] = v + _, err := GetLoadBalancerSourceRanges(annotations) + if err == nil { + t.Errorf("Expected error parsing: %q", v) + } + } + checkError("10.0.0.1/33") + checkError("foo.bar") + checkError("10.0.0.1/32,*") + checkError("10.0.0.1/32,") + checkError("10.0.0.1/32, ") + checkError("10.0.0.1") + + checkOK := func(v string) netsets.IPNet { + annotations := make(map[string]string) + annotations[AnnotationLoadBalancerSourceRangesKey] = v + cidrs, err := GetLoadBalancerSourceRanges(annotations) + if err != nil { + t.Errorf("Unexpected error parsing: %q", v) + } + return cidrs + } + cidrs := checkOK("192.168.0.1/32") + if len(cidrs) != 1 { + t.Errorf("Expected exactly one CIDR: %v", cidrs.StringSlice()) + } + cidrs = checkOK("192.168.0.1/32,192.168.0.1/32") + if len(cidrs) != 1 { + t.Errorf("Expected exactly one CIDR (after de-dup): %v", cidrs.StringSlice()) + } + cidrs = checkOK("192.168.0.1/32,192.168.0.2/32") + if len(cidrs) != 2 { + t.Errorf("Expected two CIDRs: %v", cidrs.StringSlice()) + } + cidrs = checkOK(" 192.168.0.1/32 , 192.168.0.2/32 ") + if len(cidrs) != 2 { + t.Errorf("Expected two CIDRs: %v", cidrs.StringSlice()) + } + cidrs = checkOK("") + if len(cidrs) != 1 { + t.Errorf("Expected exactly one CIDR: %v", cidrs.StringSlice()) + } + if !IsAllowAll(cidrs) { + t.Errorf("Expected default to be allow-all: %v", cidrs.StringSlice()) + } +} + +func TestAllowAll(t *testing.T) { + checkAllowAll := func(allowAll bool, cidrs ...string) { + ipnets, err := netsets.ParseIPNets(cidrs...) + if err != nil { + t.Errorf("Unexpected error parsing cidrs: %v", cidrs) + } + if allowAll != IsAllowAll(ipnets) { + t.Errorf("IsAllowAll did not return expected value for %v", cidrs) + } + } + checkAllowAll(false, "10.0.0.1/32") + checkAllowAll(false, "10.0.0.1/32", "10.0.0.2/32") + checkAllowAll(false, "10.0.0.1/32", "10.0.0.1/32") + + checkAllowAll(true, "0.0.0.0/0") + checkAllowAll(true, "192.168.0.0/0") + checkAllowAll(true, "192.168.0.1/32", "0.0.0.0/0") +} diff --git a/pkg/api/validation/validation.go b/pkg/api/validation/validation.go index 5f0d8767be1..c2bffe7366e 100644 --- a/pkg/api/validation/validation.go +++ b/pkg/api/validation/validation.go @@ -29,6 +29,7 @@ import ( "k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/api/resource" + apiservice "k8s.io/kubernetes/pkg/api/service" "k8s.io/kubernetes/pkg/capabilities" "k8s.io/kubernetes/pkg/labels" "k8s.io/kubernetes/pkg/util/intstr" @@ -1736,6 +1737,12 @@ func ValidateService(service *api.Service) field.ErrorList { nodePorts[key] = true } + _, err := apiservice.GetLoadBalancerSourceRanges(service.Annotations) + if err != nil { + v := service.Annotations[apiservice.AnnotationLoadBalancerSourceRangesKey] + allErrs = append(allErrs, field.Invalid(field.NewPath("metadata", "annotations").Key(apiservice.AnnotationLoadBalancerSourceRangesKey), v, "must be a comma separated list of CIDRs e.g. 192.168.0.0/16,10.0.0.0/8")) + } + return allErrs } diff --git a/pkg/api/validation/validation_test.go b/pkg/api/validation/validation_test.go index daa43655165..252c9adda45 100644 --- a/pkg/api/validation/validation_test.go +++ b/pkg/api/validation/validation_test.go @@ -25,6 +25,7 @@ import ( "k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/api/resource" + "k8s.io/kubernetes/pkg/api/service" "k8s.io/kubernetes/pkg/api/testapi" "k8s.io/kubernetes/pkg/api/unversioned" "k8s.io/kubernetes/pkg/capabilities" @@ -2877,6 +2878,34 @@ func TestValidateService(t *testing.T) { }, numErrs: 1, }, + { + name: "valid LoadBalancer source range annotation", + tweakSvc: func(s *api.Service) { + s.Annotations[service.AnnotationLoadBalancerSourceRangesKey] = "1.2.3.4/8, 5.6.7.8/16" + }, + numErrs: 0, + }, + { + name: "empty LoadBalancer source range annotation", + tweakSvc: func(s *api.Service) { + s.Annotations[service.AnnotationLoadBalancerSourceRangesKey] = "" + }, + numErrs: 0, + }, + { + name: "invalid LoadBalancer source range annotation (hostname)", + tweakSvc: func(s *api.Service) { + s.Annotations[service.AnnotationLoadBalancerSourceRangesKey] = "foo.bar" + }, + numErrs: 1, + }, + { + name: "invalid LoadBalancer source range annotation (invalid CIDR)", + tweakSvc: func(s *api.Service) { + s.Annotations[service.AnnotationLoadBalancerSourceRangesKey] = "1.2.3.4/33" + }, + numErrs: 1, + }, } for _, tc := range testCases { diff --git a/pkg/cloudprovider/cloud.go b/pkg/cloudprovider/cloud.go index d7062be5acf..dd120735e42 100644 --- a/pkg/cloudprovider/cloud.go +++ b/pkg/cloudprovider/cloud.go @@ -26,20 +26,6 @@ import ( "k8s.io/kubernetes/pkg/types" ) -const ( - // The value of a LBAnnotationAllowSourceRange annotation determines - // the source IP ranges to allow to access a service exposed as - // type=LoadBalancer (when accesssed through the LoadBalancer created - // by the cloud provider). - // - // It should be a comma-separated list of CIDRs, e.g. `0.0.0.0/0` to - // allow full access (the default) or `18.0.0.0/8,56.0.0.0/8` to allow - // access only from the CIDRs currently allocated to MIT & the USPS. - // - // Not all cloud providers support this annotation, though AWS & GCE do. - LBAnnotationAllowSourceRange = "service.beta.kubernetes.io/load-balancer-source-ranges" -) - // Interface is an abstract, pluggable interface for cloud providers. type Interface interface { // LoadBalancer returns a balancer interface. Also returns true if the interface is supported, false otherwise. diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 8e7ed480195..f73c8dcd94c 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -47,6 +47,7 @@ import ( "k8s.io/kubernetes/pkg/util/sets" "github.com/golang/glog" + "k8s.io/kubernetes/pkg/api/service" "k8s.io/kubernetes/pkg/api/unversioned" ) @@ -1979,7 +1980,7 @@ func (s *AWSCloud) EnsureLoadBalancer(name, region string, publicIP net.IP, port return nil, err } - sourceRanges, err := cloudprovider.GetSourceRangeAnnotations(annotations) + sourceRanges, err := service.GetLoadBalancerSourceRanges(annotations) if err != nil { return nil, err } diff --git a/pkg/cloudprovider/providers/gce/gce.go b/pkg/cloudprovider/providers/gce/gce.go index 5dda7e74659..bf458ce8e21 100644 --- a/pkg/cloudprovider/providers/gce/gce.go +++ b/pkg/cloudprovider/providers/gce/gce.go @@ -30,10 +30,12 @@ import ( "time" "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/api/service" "k8s.io/kubernetes/pkg/api/unversioned" "k8s.io/kubernetes/pkg/cloudprovider" "k8s.io/kubernetes/pkg/types" utilerrors "k8s.io/kubernetes/pkg/util/errors" + netsets "k8s.io/kubernetes/pkg/util/net/sets" "k8s.io/kubernetes/pkg/util/sets" "k8s.io/kubernetes/pkg/util/wait" @@ -577,7 +579,7 @@ func (gce *GCECloud) EnsureLoadBalancer(name, region string, requestedIP net.IP, // is because the forwarding rule is used as the indicator that the load // balancer is fully created - it's what getLoadBalancer checks for. // Check if user specified the allow source range - sourceRanges, err := cloudprovider.GetSourceRangeAnnotations(annotations) + sourceRanges, err := service.GetLoadBalancerSourceRanges(annotations) if err != nil { return nil, err } @@ -740,7 +742,7 @@ func translateAffinityType(affinityType api.ServiceAffinity) string { } } -func (gce *GCECloud) firewallNeedsUpdate(name, serviceName, region, ipAddress string, ports []*api.ServicePort, sourceRanges cloudprovider.IPNetSet) (exists bool, needsUpdate bool, err error) { +func (gce *GCECloud) firewallNeedsUpdate(name, serviceName, region, ipAddress string, ports []*api.ServicePort, sourceRanges netsets.IPNet) (exists bool, needsUpdate bool, err error) { fw, err := gce.service.Firewalls.Get(gce.projectID, makeFirewallName(name)).Do() if err != nil { if isHTTPErrorCode(err, http.StatusNotFound) { @@ -764,7 +766,7 @@ func (gce *GCECloud) firewallNeedsUpdate(name, serviceName, region, ipAddress st } // The service controller already verified that the protocol matches on all ports, no need to check. - actualSourceRanges, err := cloudprovider.ParseIPNetSet(fw.SourceRanges) + actualSourceRanges, err := netsets.ParseIPNets(fw.SourceRanges...) if err != nil { // This really shouldn't happen... GCE has returned something unexpected glog.Warningf("Error parsing firewall SourceRanges: %v", fw.SourceRanges) @@ -852,7 +854,7 @@ func (gce *GCECloud) createTargetPool(name, serviceName, region string, hosts [] return nil } -func (gce *GCECloud) createFirewall(name, region, desc string, sourceRanges cloudprovider.IPNetSet, ports []*api.ServicePort, hosts []*gceInstance) error { +func (gce *GCECloud) createFirewall(name, region, desc string, sourceRanges netsets.IPNet, ports []*api.ServicePort, hosts []*gceInstance) error { firewall, err := gce.firewallObject(name, region, desc, sourceRanges, ports, hosts) if err != nil { return err @@ -870,7 +872,7 @@ func (gce *GCECloud) createFirewall(name, region, desc string, sourceRanges clou return nil } -func (gce *GCECloud) updateFirewall(name, region, desc string, sourceRanges cloudprovider.IPNetSet, ports []*api.ServicePort, hosts []*gceInstance) error { +func (gce *GCECloud) updateFirewall(name, region, desc string, sourceRanges netsets.IPNet, ports []*api.ServicePort, hosts []*gceInstance) error { firewall, err := gce.firewallObject(name, region, desc, sourceRanges, ports, hosts) if err != nil { return err @@ -888,7 +890,7 @@ func (gce *GCECloud) updateFirewall(name, region, desc string, sourceRanges clou return nil } -func (gce *GCECloud) firewallObject(name, region, desc string, sourceRanges cloudprovider.IPNetSet, ports []*api.ServicePort, hosts []*gceInstance) (*compute.Firewall, error) { +func (gce *GCECloud) firewallObject(name, region, desc string, sourceRanges netsets.IPNet, ports []*api.ServicePort, hosts []*gceInstance) (*compute.Firewall, error) { allowedPorts := make([]string, len(ports)) for ix := range ports { allowedPorts[ix] = strconv.Itoa(ports[ix].Port) @@ -1206,7 +1208,7 @@ func (gce *GCECloud) GetFirewall(name string) (*compute.Firewall, error) { } // CreateFirewall creates the given firewall rule. -func (gce *GCECloud) CreateFirewall(name, desc string, sourceRanges cloudprovider.IPNetSet, ports []int64, hostNames []string) error { +func (gce *GCECloud) CreateFirewall(name, desc string, sourceRanges netsets.IPNet, ports []int64, hostNames []string) error { region, err := GetGCERegion(gce.localZone) if err != nil { return err @@ -1235,7 +1237,7 @@ func (gce *GCECloud) DeleteFirewall(name string) error { // UpdateFirewall applies the given firewall rule as an update to an existing // firewall rule with the same name. -func (gce *GCECloud) UpdateFirewall(name, desc string, sourceRanges cloudprovider.IPNetSet, ports []int64, hostNames []string) error { +func (gce *GCECloud) UpdateFirewall(name, desc string, sourceRanges netsets.IPNet, ports []int64, hostNames []string) error { region, err := GetGCERegion(gce.localZone) if err != nil { return err diff --git a/pkg/cloudprovider/providers/openstack/openstack.go b/pkg/cloudprovider/providers/openstack/openstack.go index a8fffac2c3e..57378674cd7 100644 --- a/pkg/cloudprovider/providers/openstack/openstack.go +++ b/pkg/cloudprovider/providers/openstack/openstack.go @@ -45,6 +45,7 @@ import ( "github.com/golang/glog" "k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/api/resource" + "k8s.io/kubernetes/pkg/api/service" "k8s.io/kubernetes/pkg/cloudprovider" "k8s.io/kubernetes/pkg/types" ) @@ -685,12 +686,12 @@ func (lb *LoadBalancer) EnsureLoadBalancer(name, region string, loadBalancerIP n return nil, fmt.Errorf("unsupported load balancer affinity: %v", affinity) } - sourceRanges, err := cloudprovider.GetSourceRangeAnnotations(annotations) + sourceRanges, err := service.GetLoadBalancerSourceRanges(annotations) if err != nil { return nil, err } - if !cloudprovider.IsAllowAll(sourceRanges) { + if !service.IsAllowAll(sourceRanges) { return nil, fmt.Errorf("Source range restrictions are not supported for openstack load balancers") } diff --git a/pkg/cloudprovider/utils.go b/pkg/cloudprovider/utils.go deleted file mode 100644 index 30ff240196f..00000000000 --- a/pkg/cloudprovider/utils.go +++ /dev/null @@ -1,99 +0,0 @@ -/* -Copyright 2016 The Kubernetes Authors All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package cloudprovider - -import ( - "fmt" - "net" - "strings" -) - -const ( - defaultLBSourceRange = "0.0.0.0/0" -) - -type IPNetSet map[string]*net.IPNet - -func ParseIPNetSet(specs []string) (IPNetSet, error) { - ipnetset := make(IPNetSet) - for _, spec := range specs { - spec = strings.TrimSpace(spec) - _, ipnet, err := net.ParseCIDR(spec) - if err != nil { - return nil, err - } - k := ipnet.String() // In case of normalization - ipnetset[k] = ipnet - } - return ipnetset, nil -} - -// StringSlice returns a []string with the String representation of each element in the set. -// Order is undefined. -func (s IPNetSet) StringSlice() []string { - a := make([]string, 0, len(s)) - for k := range s { - a = append(a, k) - } - return a -} - -// Equal checks if two IPNetSets are equal (ignoring order) -func (l IPNetSet) Equal(r IPNetSet) bool { - if len(l) != len(r) { - return false - } - - for k := range l { - _, found := r[k] - if !found { - return false - } - } - return true -} - -// Len returns the size of the set. -func (s IPNetSet) Len() int { - return len(s) -} - -// GetSourceRangeAnnotations verifies and parses the LBAnnotationAllowSourceRange annotation from a service, -// extracting the source ranges to allow, and if not present returns a default (allow-all) value. -func GetSourceRangeAnnotations(annotation map[string]string) (IPNetSet, error) { - val := annotation[LBAnnotationAllowSourceRange] - val = strings.TrimSpace(val) - if val == "" { - val = defaultLBSourceRange - } - specs := strings.Split(val, ",") - ipnets, err := ParseIPNetSet(specs) - if err != nil { - return nil, fmt.Errorf("Service annotation %s:%s is not valid. Expecting a comma-separated list of source IP ranges. For example, 10.0.0.0/24,192.168.2.0/24", LBAnnotationAllowSourceRange, val) - } - return ipnets, nil -} - -// IsAllowAll checks whether the IPNetSet contains the default allow-all policy -func IsAllowAll(ipnets IPNetSet) bool { - for _, s := range ipnets.StringSlice() { - if s == "0.0.0.0/0" { - return true - } - } - return false -} diff --git a/pkg/util/net/sets/README.md b/pkg/util/net/sets/README.md new file mode 100644 index 00000000000..b0f238a26f6 --- /dev/null +++ b/pkg/util/net/sets/README.md @@ -0,0 +1,17 @@ +This package contains hand-coded set implementations that should be similar to +the autogenerated ones in `pkg/util/sets`. + +We can't simply use net.IPNet as a map-key in Go (because it contains a +`[]byte`). + +We could use the same workaround we use here (a string representation as the +key) to autogenerate sets. If we do that, or decide on an alternate approach, +we should replace the implementations in this package with the autogenerated +versions. + +It is expected that callers will alias this import as `netsets` +i.e. `import netsets "k8s.io/kubernetes/pkg/util/net/sets"` + + + +[![Analytics](https://kubernetes-site.appspot.com/UA-36037335-10/GitHub/pkg/util/net/sets/README.md?pixel)]() diff --git a/pkg/util/net/sets/ipnet.go b/pkg/util/net/sets/ipnet.go new file mode 100644 index 00000000000..db117f63ec5 --- /dev/null +++ b/pkg/util/net/sets/ipnet.go @@ -0,0 +1,119 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sets + +import ( + "net" + "strings" +) + +type IPNet map[string]*net.IPNet + +func ParseIPNets(specs ...string) (IPNet, error) { + ipnetset := make(IPNet) + for _, spec := range specs { + spec = strings.TrimSpace(spec) + _, ipnet, err := net.ParseCIDR(spec) + if err != nil { + return nil, err + } + k := ipnet.String() // In case of normalization + ipnetset[k] = ipnet + } + return ipnetset, nil +} + +// Insert adds items to the set. +func (s IPNet) Insert(items ...*net.IPNet) { + for _, item := range items { + s[item.String()] = item + } +} + +// Delete removes all items from the set. +func (s IPNet) Delete(items ...*net.IPNet) { + for _, item := range items { + delete(s, item.String()) + } +} + +// Has returns true if and only if item is contained in the set. +func (s IPNet) Has(item *net.IPNet) bool { + _, contained := s[item.String()] + return contained +} + +// HasAll returns true if and only if all items are contained in the set. +func (s IPNet) HasAll(items ...*net.IPNet) bool { + for _, item := range items { + if !s.Has(item) { + return false + } + } + return true +} + +// Difference returns a set of objects that are not in s2 +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.Difference(s2) = {a3} +// s2.Difference(s1) = {a4, a5} +func (s IPNet) Difference(s2 IPNet) IPNet { + result := make(IPNet) + for k, i := range s { + _, found := s2[k] + if found { + continue + } + result[k] = i + } + return result +} + +// StringSlice returns a []string with the String representation of each element in the set. +// Order is undefined. +func (s IPNet) StringSlice() []string { + a := make([]string, 0, len(s)) + for k := range s { + a = append(a, k) + } + return a +} + +// IsSuperset returns true if and only if s1 is a superset of s2. +func (s1 IPNet) IsSuperset(s2 IPNet) bool { + for k := range s2 { + _, found := s1[k] + if !found { + return false + } + } + return true +} + +// Equal returns true if and only if s1 is equal (as a set) to s2. +// Two sets are equal if their membership is identical. +// (In practice, this means same elements, order doesn't matter) +func (s1 IPNet) Equal(s2 IPNet) bool { + return len(s1) == len(s2) && s1.IsSuperset(s2) +} + +// Len returns the size of the set. +func (s IPNet) Len() int { + return len(s) +} diff --git a/pkg/util/net/sets/ipnet_test.go b/pkg/util/net/sets/ipnet_test.go new file mode 100644 index 00000000000..0223d1651cd --- /dev/null +++ b/pkg/util/net/sets/ipnet_test.go @@ -0,0 +1,155 @@ +/* +Copyright 2014 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sets + +import ( + "net" + "reflect" + "sort" + "testing" +) + +func parseIPNet(s string) *net.IPNet { + _, net, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return net +} + +func TestIPNets(t *testing.T) { + s := IPNet{} + s2 := IPNet{} + if len(s) != 0 { + t.Errorf("Expected len=0: %d", len(s)) + } + a := parseIPNet("1.0.0.0/8") + b := parseIPNet("2.0.0.0/8") + c := parseIPNet("3.0.0.0/8") + d := parseIPNet("4.0.0.0/8") + + s.Insert(a, b) + if len(s) != 2 { + t.Errorf("Expected len=2: %d", len(s)) + } + s.Insert(c) + if s.Has(d) { + t.Errorf("Unexpected contents: %#v", s) + } + if !s.Has(a) { + t.Errorf("Missing contents: %#v", s) + } + s.Delete(a) + if s.Has(a) { + t.Errorf("Unexpected contents: %#v", s) + } + s.Insert(a) + if s.HasAll(a, b, d) { + t.Errorf("Unexpected contents: %#v", s) + } + if !s.HasAll(a, b) { + t.Errorf("Missing contents: %#v", s) + } + s2.Insert(a, b, d) + if s.IsSuperset(s2) { + t.Errorf("Unexpected contents: %#v", s) + } + s2.Delete(d) + if !s.IsSuperset(s2) { + t.Errorf("Missing contents: %#v", s) + } +} + +func TestIPNetSetDeleteMultiples(t *testing.T) { + s := IPNet{} + a := parseIPNet("1.0.0.0/8") + b := parseIPNet("2.0.0.0/8") + c := parseIPNet("3.0.0.0/8") + + s.Insert(a, b, c) + if len(s) != 3 { + t.Errorf("Expected len=3: %d", len(s)) + } + + s.Delete(a, c) + if len(s) != 1 { + t.Errorf("Expected len=1: %d", len(s)) + } + if s.Has(a) { + t.Errorf("Unexpected contents: %#v", s) + } + if s.Has(c) { + t.Errorf("Unexpected contents: %#v", s) + } + if !s.Has(b) { + t.Errorf("Missing contents: %#v", s) + } +} + +func TestNewIPSet(t *testing.T) { + s, err := ParseIPNets("1.0.0.0/8", "2.0.0.0/8", "3.0.0.0/8") + if err != nil { + t.Errorf("error parsing IPNets: %v", err) + } + if len(s) != 3 { + t.Errorf("Expected len=3: %d", len(s)) + } + a := parseIPNet("1.0.0.0/8") + b := parseIPNet("2.0.0.0/8") + c := parseIPNet("3.0.0.0/8") + + if !s.Has(a) || !s.Has(b) || !s.Has(c) { + t.Errorf("Unexpected contents: %#v", s) + } +} + +func TestIPNetSetDifference(t *testing.T) { + l, err := ParseIPNets("1.0.0.0/8", "2.0.0.0/8", "3.0.0.0/8") + if err != nil { + t.Errorf("error parsing IPNets: %v", err) + } + r, err := ParseIPNets("1.0.0.0/8", "2.0.0.0/8", "4.0.0.0/8", "5.0.0.0/8") + if err != nil { + t.Errorf("error parsing IPNets: %v", err) + } + c := l.Difference(r) + d := r.Difference(l) + if len(c) != 1 { + t.Errorf("Expected len=1: %d", len(c)) + } + if !c.Has(parseIPNet("3.0.0.0/8")) { + t.Errorf("Unexpected contents: %#v", c) + } + if len(d) != 2 { + t.Errorf("Expected len=2: %d", len(d)) + } + if !d.Has(parseIPNet("4.0.0.0/8")) || !d.Has(parseIPNet("5.0.0.0/8")) { + t.Errorf("Unexpected contents: %#v", d) + } +} + +func TestIPNetSetList(t *testing.T) { + s, err := ParseIPNets("3.0.0.0/8", "1.0.0.0/8", "2.0.0.0/8") + if err != nil { + t.Errorf("error parsing IPNets: %v", err) + } + l := s.StringSlice() + sort.Strings(l) + if !reflect.DeepEqual(l, []string{"1.0.0.0/8", "2.0.0.0/8", "3.0.0.0/8"}) { + t.Errorf("List gave unexpected result: %#v", l) + } +}