Rename GetNodeAddresses to GetNodeIPs, return net.IP

This commit is contained in:
Dan Winship 2023-03-03 17:43:15 -05:00
parent 2ca215fd99
commit a744a186b6
6 changed files with 65 additions and 46 deletions

View File

@ -178,7 +178,7 @@ func TestServer(t *testing.T) {
if len(listener.openPorts) != 1 {
t.Errorf("expected 1 open port, got %d\n%s", len(listener.openPorts), dump.Pretty(listener.openPorts))
}
if !listener.hasPort(":9376") {
if !listener.hasPort("0.0.0.0:9376") {
t.Errorf("expected port :9376 to be open\n%s", dump.Pretty(listener.openPorts))
}
// test the handler

View File

@ -32,7 +32,6 @@ import (
api "k8s.io/kubernetes/pkg/apis/core"
utilerrors "k8s.io/apimachinery/pkg/util/errors"
"k8s.io/apimachinery/pkg/util/sets"
utilproxy "k8s.io/kubernetes/pkg/proxy/util"
)
@ -59,18 +58,16 @@ type proxierHealthChecker interface {
}
func newServiceHealthServer(hostname string, recorder events.EventRecorder, listener listener, factory httpServerFactory, nodePortAddresses *utilproxy.NodePortAddresses, healthzServer proxierHealthChecker) ServiceHealthServer {
var nodeAddresses sets.Set[string]
// It doesn't matter whether we listen on "0.0.0.0", "::", or ""; go
// treats them all the same.
nodeIPs := []net.IP{net.IPv4zero}
// if any of the addresses is zero cidr then we listen
// to old style :<port>
if nodePortAddresses.MatchAll() {
nodeAddresses = sets.New("")
if !nodePortAddresses.MatchAll() {
ips, err := nodePortAddresses.GetNodeIPs(utilproxy.RealNetwork{})
if err == nil {
nodeIPs = ips
} else {
var err error
nodeAddresses, err = nodePortAddresses.GetNodeAddresses(utilproxy.RealNetwork{})
if err != nil || nodeAddresses.Len() == 0 {
klog.ErrorS(err, "Failed to get node ip address matching node port addresses, health check port will listen to all node addresses", "nodePortAddresses", nodePortAddresses)
nodeAddresses = sets.New(utilproxy.IPv4ZeroCIDR)
}
}
@ -81,7 +78,7 @@ func newServiceHealthServer(hostname string, recorder events.EventRecorder, list
httpFactory: factory,
healthzServer: healthzServer,
services: map[types.NamespacedName]*hcInstance{},
nodeAddresses: nodeAddresses,
nodeIPs: nodeIPs,
}
}
@ -93,7 +90,7 @@ func NewServiceHealthServer(hostname string, recorder events.EventRecorder, node
type server struct {
hostname string
// node addresses where health check port will listen on
nodeAddresses sets.Set[string]
nodeIPs []net.IP
recorder events.EventRecorder // can be nil
listener listener
httpFactory httpServerFactory
@ -167,12 +164,11 @@ func (hcI *hcInstance) listenAndServeAll(hcs *server) error {
var err error
var listener net.Listener
addresses := hcs.nodeAddresses.UnsortedList()
hcI.httpServers = make([]httpServer, 0, len(addresses))
hcI.httpServers = make([]httpServer, 0, len(hcs.nodeIPs))
// for each of the node addresses start listening and serving
for _, address := range addresses {
addr := net.JoinHostPort(address, fmt.Sprint(hcI.port))
for _, ip := range hcs.nodeIPs {
addr := net.JoinHostPort(ip.String(), fmt.Sprint(hcI.port))
// create http server
httpSrv := hcs.httpFactory.New(addr, hcHandler{name: hcI.nsn, hcs: hcs})
// start listener

View File

@ -1437,21 +1437,21 @@ func (proxier *Proxier) syncProxyRules() {
destinations,
"-j", string(kubeNodePortsChain))
} else {
nodeAddresses, err := proxier.nodePortAddresses.GetNodeAddresses(proxier.networkInterfacer)
nodeIPs, err := proxier.nodePortAddresses.GetNodeIPs(proxier.networkInterfacer)
if err != nil {
klog.ErrorS(err, "Failed to get node ip address matching nodeport cidrs, services with nodeport may not work as intended", "CIDRs", proxier.nodePortAddresses)
}
for address := range nodeAddresses {
for _, ip := range nodeIPs {
// For ipv6, Regardless of the value of localhostNodePorts is true or false, we should disallow access
// to the nodePort via lookBack address.
if isIPv6 && utilproxy.IsLoopBack(address) {
klog.ErrorS(nil, "disallow nodePort services to be accessed via ipv6 localhost address", "IP", address)
if isIPv6 && ip.IsLoopback() {
klog.ErrorS(nil, "disallow nodePort services to be accessed via ipv6 localhost address", "IP", ip.String())
continue
}
// For ipv4, When localhostNodePorts is set to false, Ignore ipv4 lookBack address
if !isIPv6 && utilproxy.IsLoopBack(address) && !proxier.localhostNodePorts {
klog.ErrorS(nil, "disallow nodePort services to be accessed via ipv4 localhost address", "IP", address)
if !isIPv6 && ip.IsLoopback() && !proxier.localhostNodePorts {
klog.ErrorS(nil, "disallow nodePort services to be accessed via ipv4 localhost address", "IP", ip.String())
continue
}
@ -1459,7 +1459,7 @@ func (proxier *Proxier) syncProxyRules() {
proxier.natRules.Write(
"-A", string(kubeServicesChain),
"-m", "comment", "--comment", `"kubernetes service nodeports; NOTE: this must be the last rule in this chain"`,
"-d", address,
"-d", ip.String(),
"-j", string(kubeNodePortsChain))
}
}

View File

@ -1011,16 +1011,14 @@ func (proxier *Proxier) syncProxyRules() {
nodeIPs = append(nodeIPs, netutils.ParseIPSloppy(ipStr))
}
} else {
nodeAddrSet, err := proxier.nodePortAddresses.GetNodeAddresses(proxier.networkInterfacer)
allNodeIPs, err := proxier.nodePortAddresses.GetNodeIPs(proxier.networkInterfacer)
if err != nil {
klog.ErrorS(err, "Failed to get node IP address matching nodeport cidr")
} else {
for address := range nodeAddrSet {
a := netutils.ParseIPSloppy(address)
if a.IsLoopback() {
continue
for _, ip := range allNodeIPs {
if !ip.IsLoopback() {
nodeIPs = append(nodeIPs, ip)
}
nodeIPs = append(nodeIPs, a)
}
}
}

View File

@ -21,7 +21,6 @@ import (
"net"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/sets"
netutils "k8s.io/utils/net"
)
@ -91,16 +90,17 @@ func (npa *NodePortAddresses) MatchAll() bool {
return npa.matchAll
}
// GetNodeAddresses return all matched node IP addresses for npa's CIDRs. If no matching
// GetNodeIPs return all matched node IP addresses for npa's CIDRs. If no matching
// IPs are found, it returns an empty list.
// NetworkInterfacer is injected for test purpose.
func (npa *NodePortAddresses) GetNodeAddresses(nw NetworkInterfacer) (sets.Set[string], error) {
func (npa *NodePortAddresses) GetNodeIPs(nw NetworkInterfacer) ([]net.IP, error) {
addrs, err := nw.InterfaceAddrs()
if err != nil {
return nil, fmt.Errorf("error listing all interfaceAddrs from host, error: %v", err)
}
uniqueAddressList := sets.New[string]()
// Use a map to dedup matches
addresses := make(map[string]net.IP)
for _, cidr := range npa.cidrs {
for _, addr := range addrs {
var ip net.IP
@ -115,12 +115,17 @@ func (npa *NodePortAddresses) GetNodeAddresses(nw NetworkInterfacer) (sets.Set[s
}
if cidr.Contains(ip) {
uniqueAddressList.Insert(ip.String())
addresses[ip.String()] = ip
}
}
}
return uniqueAddressList, nil
ips := make([]net.IP, 0, len(addresses))
for _, ip := range addresses {
ips = append(ips, ip)
}
return ips, nil
}
// ContainsIPv4Loopback returns true if npa's CIDRs contain an IPv4 loopback address.

View File

@ -17,6 +17,7 @@ limitations under the License.
package util
import (
"fmt"
"net"
"testing"
@ -31,7 +32,24 @@ type InterfaceAddrsPair struct {
addrs []net.Addr
}
func TestGetNodeAddresses(t *testing.T) {
func checkNodeIPs(expected sets.Set[string], actual []net.IP) error {
notFound := expected.Clone()
extra := sets.New[string]()
for _, ip := range actual {
str := ip.String()
if notFound.Has(str) {
notFound.Delete(str)
} else {
extra.Insert(str)
}
}
if len(notFound) != 0 || len(extra) != 0 {
return fmt.Errorf("not found: %v, extra: %v", notFound.UnsortedList(), extra.UnsortedList())
}
return nil
}
func TestGetNodeIPs(t *testing.T) {
type expectation struct {
matchAll bool
ips sets.Set[string]
@ -367,16 +385,18 @@ func TestGetNodeAddresses(t *testing.T) {
t.Errorf("unexpected MatchAll(%s), expected: %v", family, tc.expected[family].matchAll)
}
addrList, err := npa.GetNodeAddresses(nw)
ips, err := npa.GetNodeIPs(nw)
expectedIPs := tc.expected[family].ips
// The fake InterfaceAddrs() never returns an error, so
// the only error GetNodeAddresses will return is "no
// the only error GetNodeIPs will return is "no
// addresses found".
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !addrList.Equal(tc.expected[family].ips) {
t.Errorf("unexpected mismatch for %s, expected: %v, got: %v", family, tc.expected[family].ips, addrList)
err = checkNodeIPs(expectedIPs, ips)
if err != nil {
t.Errorf("unexpected mismatch for %s: %v", family, err)
}
}
})