From 9f37096c38b65947d1cf7eef6c1e3722bb88fc97 Mon Sep 17 00:00:00 2001 From: Lars Ekman Date: Sat, 21 Aug 2021 08:44:27 +0200 Subject: [PATCH] Kube-proxy/ipvs; Use go "net" lib to get nodeIPs The nodeIPs to be used for nodePorts were collected using netlink which was unnecessary complex and caused se #93858 --- pkg/proxy/ipvs/netlink.go | 11 ++- pkg/proxy/ipvs/netlink_linux.go | 103 ++++++++--------------- pkg/proxy/ipvs/netlink_unsupported.go | 29 +++++-- pkg/proxy/ipvs/proxier.go | 42 +++++---- pkg/proxy/ipvs/proxier_test.go | 117 ++++++++++++++++++++++++-- pkg/proxy/ipvs/testing/fake.go | 39 ++++++--- pkg/proxy/ipvs/testing/fake_test.go | 12 +-- pkg/proxy/util/utils.go | 21 +++++ pkg/proxy/util/utils_test.go | 116 +++++++++++++++++++++++++ 9 files changed, 369 insertions(+), 121 deletions(-) diff --git a/pkg/proxy/ipvs/netlink.go b/pkg/proxy/ipvs/netlink.go index aa4dbc5b6db..2d77bb32044 100644 --- a/pkg/proxy/ipvs/netlink.go +++ b/pkg/proxy/ipvs/netlink.go @@ -32,7 +32,12 @@ type NetLinkHandle interface { DeleteDummyDevice(devName string) error // ListBindAddress will list all IP addresses which are bound in a given interface ListBindAddress(devName string) ([]string, error) - // GetLocalAddresses returns all unique local type IP addresses based on specified device and filter device - // If device is not specified, it will list all unique local type addresses except filter device addresses - GetLocalAddresses(dev, filterDev string) (sets.String, error) + // GetAllLocalAddresses return all local addresses on the node. + // Only the addresses of the current family are returned. + // IPv6 link-local and loopback addresses are excluded. + GetAllLocalAddresses() (sets.String, error) + // GetLocalAddresses return all local addresses for an interface. + // Only the addresses of the current family are returned. + // IPv6 link-local and loopback addresses are excluded. + GetLocalAddresses(dev string) (sets.String, error) } diff --git a/pkg/proxy/ipvs/netlink_linux.go b/pkg/proxy/ipvs/netlink_linux.go index ceb643bb76e..ebaae4b4bee 100644 --- a/pkg/proxy/ipvs/netlink_linux.go +++ b/pkg/proxy/ipvs/netlink_linux.go @@ -21,8 +21,10 @@ package ipvs import ( "fmt" + "net" "k8s.io/apimachinery/pkg/util/sets" + utilproxy "k8s.io/kubernetes/pkg/proxy/util" netutils "k8s.io/utils/net" "github.com/vishvananda/netlink" @@ -124,72 +126,41 @@ func (h *netlinkHandle) ListBindAddress(devName string) ([]string, error) { return ips, nil } -// GetLocalAddresses lists all LOCAL type IP addresses from host based on filter device. -// If dev is not specified, it's equivalent to exec: -// $ ip route show table local type local proto kernel -// 10.0.0.1 dev kube-ipvs0 scope host src 10.0.0.1 -// 10.0.0.10 dev kube-ipvs0 scope host src 10.0.0.10 -// 10.0.0.252 dev kube-ipvs0 scope host src 10.0.0.252 -// 100.106.89.164 dev eth0 scope host src 100.106.89.164 -// 127.0.0.0/8 dev lo scope host src 127.0.0.1 -// 127.0.0.1 dev lo scope host src 127.0.0.1 -// 172.17.0.1 dev docker0 scope host src 172.17.0.1 -// 192.168.122.1 dev virbr0 scope host src 192.168.122.1 -// Then cut the unique src IP fields, -// --> result set: [10.0.0.1, 10.0.0.10, 10.0.0.252, 100.106.89.164, 127.0.0.1, 172.17.0.1, 192.168.122.1] - -// If dev is specified, it's equivalent to exec: -// $ ip route show table local type local proto kernel dev kube-ipvs0 -// 10.0.0.1 scope host src 10.0.0.1 -// 10.0.0.10 scope host src 10.0.0.10 -// Then cut the unique src IP fields, -// --> result set: [10.0.0.1, 10.0.0.10] - -// If filterDev is specified, the result will discard route of specified device and cut src from other routes. -func (h *netlinkHandle) GetLocalAddresses(dev, filterDev string) (sets.String, error) { - chosenLinkIndex, filterLinkIndex := -1, -1 - if dev != "" { - link, err := h.LinkByName(dev) - if err != nil { - return nil, fmt.Errorf("error get device %s, err: %v", dev, err) - } - chosenLinkIndex = link.Attrs().Index - } else if filterDev != "" { - link, err := h.LinkByName(filterDev) - if err != nil { - return nil, fmt.Errorf("error get filter device %s, err: %v", filterDev, err) - } - filterLinkIndex = link.Attrs().Index - } - - routeFilter := &netlink.Route{ - Table: unix.RT_TABLE_LOCAL, - Type: unix.RTN_LOCAL, - Protocol: unix.RTPROT_KERNEL, - } - filterMask := netlink.RT_FILTER_TABLE | netlink.RT_FILTER_TYPE | netlink.RT_FILTER_PROTOCOL - - // find chosen device - if chosenLinkIndex != -1 { - routeFilter.LinkIndex = chosenLinkIndex - filterMask |= netlink.RT_FILTER_OIF - } - routes, err := h.RouteListFiltered(netlink.FAMILY_ALL, routeFilter, filterMask) +// GetAllLocalAddresses return all local addresses on the node. +// Only the addresses of the current family are returned. +// IPv6 link-local and loopback addresses are excluded. +func (h *netlinkHandle) GetAllLocalAddresses() (sets.String, error) { + addr, err := net.InterfaceAddrs() if err != nil { - return nil, fmt.Errorf("error list route table, err: %v", err) + return nil, fmt.Errorf("Could not get addresses: %v", err) } - res := sets.NewString() - for _, route := range routes { - if route.LinkIndex == filterLinkIndex { - continue - } - if h.isIPv6 { - if route.Dst.IP.To4() == nil && !route.Dst.IP.IsLinkLocalUnicast() { - res.Insert(route.Dst.IP.String()) - } - } else if route.Src != nil { - res.Insert(route.Src.String()) - } - } - return res, nil + return utilproxy.AddressSet(h.isValidForSet, addr), nil +} + +// GetLocalAddresses return all local addresses for an interface. +// Only the addresses of the current family are returned. +// IPv6 link-local and loopback addresses are excluded. +func (h *netlinkHandle) GetLocalAddresses(dev string) (sets.String, error) { + ifi, err := net.InterfaceByName(dev) + if err != nil { + return nil, fmt.Errorf("Could not get interface %s: %v", dev, err) + } + addr, err := ifi.Addrs() + if err != nil { + return nil, fmt.Errorf("Can't get addresses from %s: %v", ifi.Name, err) + } + return utilproxy.AddressSet(h.isValidForSet, addr), nil +} + +func (h *netlinkHandle) isValidForSet(ip net.IP) bool { + if h.isIPv6 != netutils.IsIPv6(ip) { + return false + } + if h.isIPv6 && ip.IsLinkLocalUnicast() { + return false + } + if ip.IsLoopback() { + return false + } + return true } diff --git a/pkg/proxy/ipvs/netlink_unsupported.go b/pkg/proxy/ipvs/netlink_unsupported.go index 06add05fa2f..eec2f5bada7 100644 --- a/pkg/proxy/ipvs/netlink_unsupported.go +++ b/pkg/proxy/ipvs/netlink_unsupported.go @@ -21,44 +21,57 @@ package ipvs import ( "fmt" + "net" "k8s.io/apimachinery/pkg/util/sets" ) -type emptyHandle struct { +// The type must match the one in proxier_test.go +type netlinkHandle struct { + isIPv6 bool } // NewNetLinkHandle will create an EmptyHandle func NewNetLinkHandle(ipv6 bool) NetLinkHandle { - return &emptyHandle{} + return &netlinkHandle{} } // EnsureAddressBind checks if address is bound to the interface and, if not, binds it. If the address is already bound, return true. -func (h *emptyHandle) EnsureAddressBind(address, devName string) (exist bool, err error) { +func (h *netlinkHandle) EnsureAddressBind(address, devName string) (exist bool, err error) { return false, fmt.Errorf("netlink not supported for this platform") } // UnbindAddress unbind address from the interface -func (h *emptyHandle) UnbindAddress(address, devName string) error { +func (h *netlinkHandle) UnbindAddress(address, devName string) error { return fmt.Errorf("netlink not supported for this platform") } // EnsureDummyDevice is part of interface -func (h *emptyHandle) EnsureDummyDevice(devName string) (bool, error) { +func (h *netlinkHandle) EnsureDummyDevice(devName string) (bool, error) { return false, fmt.Errorf("netlink is not supported in this platform") } // DeleteDummyDevice is part of interface. -func (h *emptyHandle) DeleteDummyDevice(devName string) error { +func (h *netlinkHandle) DeleteDummyDevice(devName string) error { return fmt.Errorf("netlink is not supported in this platform") } // ListBindAddress is part of interface. -func (h *emptyHandle) ListBindAddress(devName string) ([]string, error) { +func (h *netlinkHandle) ListBindAddress(devName string) ([]string, error) { + return nil, fmt.Errorf("netlink is not supported in this platform") +} + +// GetAllLocalAddresses is part of interface. +func (h *netlinkHandle) GetAllLocalAddresses() (sets.String, error) { return nil, fmt.Errorf("netlink is not supported in this platform") } // GetLocalAddresses is part of interface. -func (h *emptyHandle) GetLocalAddresses(dev, filterDev string) (sets.String, error) { +func (h *netlinkHandle) GetLocalAddresses(dev string) (sets.String, error) { return nil, fmt.Errorf("netlink is not supported in this platform") } + +// Must match the one in proxier_test.go +func (h *netlinkHandle) isValidForSet(ip net.IP) bool { + return false +} diff --git a/pkg/proxy/ipvs/proxier.go b/pkg/proxy/ipvs/proxier.go index 9211bfda1ba..0178f188325 100644 --- a/pkg/proxy/ipvs/proxier.go +++ b/pkg/proxy/ipvs/proxier.go @@ -289,33 +289,31 @@ type realIPGetter struct { nl NetLinkHandle } -// NodeIPs returns all LOCAL type IP addresses from host which are taken as the Node IPs of NodePort service. -// It will list source IP exists in local route table with `kernel` protocol type, and filter out IPVS proxier -// created dummy device `kube-ipvs0` For example, -// $ ip route show table local type local proto kernel -// 10.0.0.1 dev kube-ipvs0 scope host src 10.0.0.1 -// 10.0.0.10 dev kube-ipvs0 scope host src 10.0.0.10 -// 10.0.0.252 dev kube-ipvs0 scope host src 10.0.0.252 -// 100.106.89.164 dev eth0 scope host src 100.106.89.164 -// 127.0.0.0/8 dev lo scope host src 127.0.0.1 -// 127.0.0.1 dev lo scope host src 127.0.0.1 -// 172.17.0.1 dev docker0 scope host src 172.17.0.1 -// 192.168.122.1 dev virbr0 scope host src 192.168.122.1 -// Then filter out dev==kube-ipvs0, and cut the unique src IP fields, -// Node IP set: [100.106.89.164, 172.17.0.1, 192.168.122.1] -// Note that loopback addresses are excluded. +// NodeIPs returns all LOCAL type IP addresses from host which are +// taken as the Node IPs of NodePort service. Filtered addresses: +// +// * Loopback addresses +// * Addresses of the "other" family (not handled by this proxier instance) +// * Link-local IPv6 addresses +// * Addresses on the created dummy device `kube-ipvs0` +// func (r *realIPGetter) NodeIPs() (ips []net.IP, err error) { - // Pass in empty filter device name for list all LOCAL type addresses. - nodeAddress, err := r.nl.GetLocalAddresses("", DefaultDummyDevice) + + nodeAddress, err := r.nl.GetAllLocalAddresses() if err != nil { return nil, fmt.Errorf("error listing LOCAL type addresses from host, error: %v", err) } + + // We must exclude the addresses on the IPVS dummy interface + bindedAddress, err := r.BindedIPs() + if err != nil { + return nil, err + } + ipset := nodeAddress.Difference(bindedAddress) + // translate ip string to IP - for _, ipStr := range nodeAddress.UnsortedList() { + for _, ipStr := range ipset.UnsortedList() { a := netutils.ParseIPSloppy(ipStr) - if a.IsLoopback() { - continue - } ips = append(ips, a) } return ips, nil @@ -323,7 +321,7 @@ func (r *realIPGetter) NodeIPs() (ips []net.IP, err error) { // BindedIPs returns all addresses that are binded to the IPVS dummy interface kube-ipvs0 func (r *realIPGetter) BindedIPs() (sets.String, error) { - return r.nl.GetLocalAddresses(DefaultDummyDevice, "") + return r.nl.GetLocalAddresses(DefaultDummyDevice) } // Proxier implements proxy.Provider diff --git a/pkg/proxy/ipvs/proxier_test.go b/pkg/proxy/ipvs/proxier_test.go index 0654d25c9c8..b25431e3b2a 100644 --- a/pkg/proxy/ipvs/proxier_test.go +++ b/pkg/proxy/ipvs/proxier_test.go @@ -385,6 +385,7 @@ func TestCanUseIPVSProxier(t *testing.T) { func TestGetNodeIPs(t *testing.T) { testCases := []struct { + isIPv6 bool devAddresses map[string][]string expectIPs []string }{ @@ -405,22 +406,22 @@ func TestGetNodeIPs(t *testing.T) { }, // case 3 { - devAddresses: map[string][]string{"encap0": {"10.20.30.40"}, "lo": {"127.0.0.1"}, "docker0": {"172.17.0.1"}}, + devAddresses: map[string][]string{"encap0": {"10.20.30.40", "fe80::200:ff:fe01:1"}, "lo": {"127.0.0.1", "::1"}, "docker0": {"172.17.0.1"}}, expectIPs: []string{"10.20.30.40", "172.17.0.1"}, }, // case 4 { - devAddresses: map[string][]string{"encaps9": {"10.20.30.40"}, "lo": {"127.0.0.1"}, "encap7": {"10.20.30.31"}}, + devAddresses: map[string][]string{"encaps9": {"10.20.30.40"}, "lo": {"127.0.0.1", "::1"}, "encap7": {"1000::", "10.20.30.31"}}, expectIPs: []string{"10.20.30.40", "10.20.30.31"}, }, // case 5 { - devAddresses: map[string][]string{"kube-ipvs0": {"1.2.3.4"}, "lo": {"127.0.0.1"}, "encap7": {"10.20.30.31"}}, + devAddresses: map[string][]string{"kube-ipvs0": {"2000::", "1.2.3.4"}, "lo": {"127.0.0.1", "::1"}, "encap7": {"1000::", "10.20.30.31"}}, expectIPs: []string{"10.20.30.31"}, }, // case 6 { - devAddresses: map[string][]string{"kube-ipvs0": {"1.2.3.4", "2.3.4.5"}, "lo": {"127.0.0.1"}}, + devAddresses: map[string][]string{"kube-ipvs0": {"1.2.3.4", "2.3.4.5"}, "lo": {"127.0.0.1", "::1"}}, expectIPs: []string{}, }, // case 7 @@ -430,18 +431,31 @@ func TestGetNodeIPs(t *testing.T) { }, // case 8 { - devAddresses: map[string][]string{"kube-ipvs0": {"1.2.3.4", "2.3.4.5"}, "eth5": {"3.4.5.6"}, "lo": {"127.0.0.1"}}, + devAddresses: map[string][]string{"kube-ipvs0": {"1.2.3.4", "2.3.4.5"}, "eth5": {"3.4.5.6"}, "lo": {"127.0.0.1", "::1"}}, expectIPs: []string{"3.4.5.6"}, }, // case 9 { - devAddresses: map[string][]string{"ipvs0": {"1.2.3.4"}, "lo": {"127.0.0.1"}, "encap7": {"10.20.30.31"}}, + devAddresses: map[string][]string{"ipvs0": {"1.2.3.4"}, "lo": {"127.0.0.1", "::1"}, "encap7": {"10.20.30.31"}}, expectIPs: []string{"10.20.30.31", "1.2.3.4"}, }, + // case 10 + { + isIPv6: true, + devAddresses: map[string][]string{"ipvs0": {"1.2.3.4", "1000::"}, "lo": {"127.0.0.1", "::1"}, "encap7": {"10.20.30.31", "2000::", "fe80::200:ff:fe01:1"}}, + expectIPs: []string{"1000::", "2000::"}, + }, + // case 11 + { + isIPv6: true, + devAddresses: map[string][]string{"ipvs0": {"1.2.3.4", "1000::"}, "lo": {"127.0.0.1", "::1"}, "encap7": {"10.20.30.31", "2000::", "fe80::200:ff:fe01:1"}, "kube-ipvs0": {"1.2.3.4", "2.3.4.5", "2000::"}}, + expectIPs: []string{"1000::"}, + }, } for i := range testCases { fake := netlinktest.NewFakeNetlinkHandle() + fake.IsIPv6 = testCases[i].isIPv6 for dev, addresses := range testCases[i].devAddresses { fake.SetLocalAddresses(dev, addresses...) } @@ -5354,3 +5368,94 @@ func Test_EndpointSliceOnlyReadyAndTerminatingLocalWithFeatureGateDisabled(t *te assert.Nil(t, rsErr2, "Expected no error getting real servers") assert.Len(t, realServers2, 0, "Expected 0 real servers") } + +func TestIpIsValidForSet(t *testing.T) { + testCases := []struct { + isIPv6 bool + ip string + res bool + }{ + { + false, + "127.0.0.1", + false, + }, + { + false, + "127.0.0.0", + false, + }, + { + false, + "127.6.7.8", + false, + }, + { + false, + "8.8.8.8", + true, + }, + { + false, + "192.168.0.1", + true, + }, + { + false, + "169.254.0.0", + true, + }, + { + false, + "::ffff:169.254.0.0", // IPv6 mapped IPv4 + true, + }, + { + false, + "1000::", + false, + }, + // IPv6 + { + true, + "::1", + false, + }, + { + true, + "1000::", + true, + }, + { + true, + "fe80::200:ff:fe01:1", + false, + }, + { + true, + "8.8.8.8", + false, + }, + { + true, + "::ffff:8.8.8.8", + false, + }, + } + + for _, tc := range testCases { + v := &netlinkHandle{} + v.isIPv6 = tc.isIPv6 + ip := netutils.ParseIPSloppy(tc.ip) + if ip == nil { + t.Errorf("Parse error: %s", tc.ip) + } + if v.isValidForSet(ip) != tc.res { + if tc.isIPv6 { + t.Errorf("IPv6: %s", tc.ip) + } else { + t.Errorf("IPv4: %s", tc.ip) + } + } + } +} diff --git a/pkg/proxy/ipvs/testing/fake.go b/pkg/proxy/ipvs/testing/fake.go index 93886f6be7b..7ece461b436 100644 --- a/pkg/proxy/ipvs/testing/fake.go +++ b/pkg/proxy/ipvs/testing/fake.go @@ -18,6 +18,7 @@ package testing import ( "fmt" + "k8s.io/utils/net" "k8s.io/apimachinery/pkg/util/sets" ) @@ -27,6 +28,8 @@ type FakeNetlinkHandle struct { // localAddresses is a network interface name to all of its IP addresses map, e.g. // eth0 -> [1.2.3.4, 10.20.30.40] localAddresses map[string][]string + + IsIPv6 bool } // NewFakeNetlinkHandle will create a new FakeNetlinkHandle @@ -112,23 +115,25 @@ func (h *FakeNetlinkHandle) ListBindAddress(devName string) ([]string, error) { } // GetLocalAddresses is a mock implementation -func (h *FakeNetlinkHandle) GetLocalAddresses(dev, filterDev string) (sets.String, error) { +func (h *FakeNetlinkHandle) GetLocalAddresses(dev string) (sets.String, error) { res := sets.NewString() - if len(dev) != 0 { - // list all addresses from a given network interface. - for _, addr := range h.localAddresses[dev] { + // list all addresses from a given network interface. + for _, addr := range h.localAddresses[dev] { + if h.isValidForSet(addr) { res.Insert(addr) } - return res, nil } - // If filterDev is not given, will list all addresses from all available network interface. + return res, nil +} +func (h *FakeNetlinkHandle) GetAllLocalAddresses() (sets.String, error) { + res := sets.NewString() + // List all addresses from all available network interfaces. for linkName := range h.localAddresses { - if linkName == filterDev { - continue - } // list all addresses from a given network interface. for _, addr := range h.localAddresses[linkName] { - res.Insert(addr) + if h.isValidForSet(addr) { + res.Insert(addr) + } } } return res, nil @@ -146,3 +151,17 @@ func (h *FakeNetlinkHandle) SetLocalAddresses(dev string, ips ...string) error { h.localAddresses[dev] = append(h.localAddresses[dev], ips...) return nil } + +func (h *FakeNetlinkHandle) isValidForSet(ipString string) bool { + ip := net.ParseIPSloppy(ipString) + if h.IsIPv6 != (ip.To4() == nil) { + return false + } + if h.IsIPv6 && ip.IsLinkLocalUnicast() { + return false + } + if ip.IsLoopback() { + return false + } + return true +} diff --git a/pkg/proxy/ipvs/testing/fake_test.go b/pkg/proxy/ipvs/testing/fake_test.go index 2c8a5265ba6..1c7a16d97fd 100644 --- a/pkg/proxy/ipvs/testing/fake_test.go +++ b/pkg/proxy/ipvs/testing/fake_test.go @@ -27,22 +27,22 @@ func TestSetGetLocalAddresses(t *testing.T) { fake := NewFakeNetlinkHandle() fake.SetLocalAddresses("eth0", "1.2.3.4") expected := sets.NewString("1.2.3.4") - addr, _ := fake.GetLocalAddresses("eth0", "") + addr, _ := fake.GetLocalAddresses("eth0") if !reflect.DeepEqual(expected, addr) { t.Errorf("Unexpected mismatch, expected: %v, got: %v", expected, addr) } - list, _ := fake.GetLocalAddresses("", "") + list, _ := fake.GetAllLocalAddresses() if !reflect.DeepEqual(expected, list) { t.Errorf("Unexpected mismatch, expected: %v, got: %v", expected, list) } fake.SetLocalAddresses("lo", "127.0.0.1") - expected = sets.NewString("127.0.0.1") - addr, _ = fake.GetLocalAddresses("lo", "") + expected = sets.NewString() + addr, _ = fake.GetLocalAddresses("lo") if !reflect.DeepEqual(expected, addr) { t.Errorf("Unexpected mismatch, expected: %v, got: %v", expected, addr) } - list, _ = fake.GetLocalAddresses("", "") - expected = sets.NewString("1.2.3.4", "127.0.0.1") + list, _ = fake.GetAllLocalAddresses() + expected = sets.NewString("1.2.3.4") if !reflect.DeepEqual(expected, list) { t.Errorf("Unexpected mismatch, expected: %v, got: %v", expected, list) } diff --git a/pkg/proxy/util/utils.go b/pkg/proxy/util/utils.go index 5cee78ccdb0..f1d3c10e6f3 100644 --- a/pkg/proxy/util/utils.go +++ b/pkg/proxy/util/utils.go @@ -252,6 +252,27 @@ func GetNodeAddresses(cidrs []string, nw NetworkInterfacer) (sets.String, error) return uniqueAddressList, nil } +// AddressSet validates the addresses in the slice using the "isValid" function. +// Addresses that pass the validation are returned as a string Set. +func AddressSet(isValid func(ip net.IP) bool, addrs []net.Addr) sets.String { + ips := sets.NewString() + for _, a := range addrs { + var ip net.IP + switch v := a.(type) { + case *net.IPAddr: + ip = v.IP + case *net.IPNet: + ip = v.IP + default: + continue + } + if isValid(ip) { + ips.Insert(ip.String()) + } + } + return ips +} + // LogAndEmitIncorrectIPVersionEvent logs and emits incorrect IP version event. func LogAndEmitIncorrectIPVersionEvent(recorder events.EventRecorder, fieldName, fieldValue, svcNamespace, svcName string, svcUID types.UID) { errMsg := fmt.Sprintf("%s in %s has incorrect IP version", fieldValue, fieldName) diff --git a/pkg/proxy/util/utils_test.go b/pkg/proxy/util/utils_test.go index c61c7c2cb51..d5fced8e86f 100644 --- a/pkg/proxy/util/utils_test.go +++ b/pkg/proxy/util/utils_test.go @@ -1283,3 +1283,119 @@ func randSeq() string { } return string(b) } + +func mustParseIPAddr(str string) net.Addr { + a, err := net.ResolveIPAddr("ip", str) + if err != nil { + panic("mustParseIPAddr") + } + return a +} +func mustParseIPNet(str string) net.Addr { + _, n, err := netutils.ParseCIDRSloppy(str) + if err != nil { + panic("mustParseIPNet") + } + return n +} +func mustParseUnix(str string) net.Addr { + n, err := net.ResolveUnixAddr("unix", str) + if err != nil { + panic("mustParseUnix") + } + return n +} + +type cidrValidator struct { + cidr *net.IPNet +} + +func (v *cidrValidator) isValid(ip net.IP) bool { + return v.cidr.Contains(ip) +} +func newCidrValidator(cidr string) func(ip net.IP) bool { + _, n, err := netutils.ParseCIDRSloppy(cidr) + if err != nil { + panic("mustParseIPNet") + } + obj := cidrValidator{n} + return obj.isValid +} + +func TestAddressSet(t *testing.T) { + testCases := []struct { + name string + validator func(ip net.IP) bool + input []net.Addr + expected sets.String + }{ + { + "Empty", + func(ip net.IP) bool { return false }, + nil, + sets.NewString(), + }, + { + "Reject IPAddr x 2", + func(ip net.IP) bool { return false }, + []net.Addr{ + mustParseIPAddr("8.8.8.8"), + mustParseIPAddr("1000::"), + }, + sets.NewString(), + }, + { + "Accept IPAddr x 2", + func(ip net.IP) bool { return true }, + []net.Addr{ + mustParseIPAddr("8.8.8.8"), + mustParseIPAddr("1000::"), + }, + sets.NewString("8.8.8.8", "1000::"), + }, + { + "Accept IPNet x 2", + func(ip net.IP) bool { return true }, + []net.Addr{ + mustParseIPNet("8.8.8.8/32"), + mustParseIPNet("1000::/128"), + }, + sets.NewString("8.8.8.8", "1000::"), + }, + { + "Accept Unix x 2", + func(ip net.IP) bool { return true }, + []net.Addr{ + mustParseUnix("/tmp/sock1"), + mustParseUnix("/tmp/sock2"), + }, + sets.NewString(), + }, + { + "Cidr IPv4", + newCidrValidator("192.168.1.0/24"), + []net.Addr{ + mustParseIPAddr("8.8.8.8"), + mustParseIPAddr("1000::"), + mustParseIPAddr("192.168.1.1"), + }, + sets.NewString("192.168.1.1"), + }, + { + "Cidr IPv6", + newCidrValidator("1000::/64"), + []net.Addr{ + mustParseIPAddr("8.8.8.8"), + mustParseIPAddr("1000::"), + mustParseIPAddr("192.168.1.1"), + }, + sets.NewString("1000::"), + }, + } + + for _, tc := range testCases { + if !tc.expected.Equal(AddressSet(tc.validator, tc.input)) { + t.Errorf("%s", tc.name) + } + } +}