dockershim hostport respect IPFamily

This commit is contained in:
Antonio Ojea 2021-02-04 10:03:24 +01:00
parent f7d86e8b1c
commit ad4776ba54
5 changed files with 101 additions and 46 deletions

View File

@ -148,14 +148,14 @@ func (f *fakeIPTables) ensureRule(position utiliptables.RulePosition, tableName
return true, nil return true, nil
} }
if position == utiliptables.Prepend { switch position {
case utiliptables.Prepend:
chain.rules = append([]string{rule}, chain.rules...) chain.rules = append([]string{rule}, chain.rules...)
} else if position == utiliptables.Append { case utiliptables.Append:
chain.rules = append(chain.rules, rule) chain.rules = append(chain.rules, rule)
} else { default:
return false, fmt.Errorf("unknown position argument %q", position) return false, fmt.Errorf("unknown position argument %q", position)
} }
return false, nil return false, nil
} }
@ -185,7 +185,7 @@ func normalizeRule(rule string) (string, error) {
// Normalize un-prefixed IP addresses like iptables does // Normalize un-prefixed IP addresses like iptables does
if net.ParseIP(arg) != nil { if net.ParseIP(arg) != nil {
arg = arg + "/32" arg += "/32"
} }
if len(normalized) > 0 { if len(normalized) > 0 {
@ -281,7 +281,10 @@ func (f *fakeIPTables) restore(restoreTableName utiliptables.Table, data []byte,
if strings.HasPrefix(line, ":") { if strings.HasPrefix(line, ":") {
chainName := utiliptables.Chain(strings.Split(line[1:], " ")[0]) chainName := utiliptables.Chain(strings.Split(line[1:], " ")[0])
if flush == utiliptables.FlushTables { if flush == utiliptables.FlushTables {
table, chain, _ := f.getChain(tableName, chainName) table, chain, err := f.getChain(tableName, chainName)
if err != nil {
return err
}
if chain != nil { if chain != nil {
delete(table.chains, string(chainName)) delete(table.chains, string(chainName))
} }

View File

@ -54,7 +54,17 @@ type PodPortMapping struct {
IP net.IP IP net.IP
} }
// ipFamily refers to a specific family if not empty, i.e. "4" or "6".
type ipFamily string
// Constants for valid IPFamily:
const (
IPv4 ipFamily = "4"
IPv6 ipFamily = "6"
)
type hostport struct { type hostport struct {
ipFamily ipFamily
ip string ip string
port int32 port int32
protocol string protocol string
@ -84,17 +94,19 @@ func openLocalPort(hp *hostport) (closeable, error) {
address := net.JoinHostPort(hp.ip, strconv.Itoa(int(hp.port))) address := net.JoinHostPort(hp.ip, strconv.Itoa(int(hp.port)))
switch hp.protocol { switch hp.protocol {
case "tcp": case "tcp":
listener, err := net.Listen("tcp", address) network := "tcp" + string(hp.ipFamily)
listener, err := net.Listen(network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
socket = listener socket = listener
case "udp": case "udp":
addr, err := net.ResolveUDPAddr("udp", address) network := "udp" + string(hp.ipFamily)
addr, err := net.ResolveUDPAddr(network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn, err := net.ListenUDP("udp", addr) conn, err := net.ListenUDP(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -107,8 +119,9 @@ func openLocalPort(hp *hostport) (closeable, error) {
} }
// portMappingToHostport creates hostport structure based on input portmapping // portMappingToHostport creates hostport structure based on input portmapping
func portMappingToHostport(portMapping *PortMapping) hostport { func portMappingToHostport(portMapping *PortMapping, family ipFamily) hostport {
return hostport{ return hostport{
ipFamily: family,
ip: portMapping.HostIP, ip: portMapping.HostIP,
port: portMapping.HostPort, port: portMapping.HostPort,
protocol: strings.ToLower(string(portMapping.Protocol)), protocol: strings.ToLower(string(portMapping.Protocol)),
@ -129,9 +142,11 @@ func ensureKubeHostportChains(iptables utiliptables.Interface, natInterfaceName
{utiliptables.TableNAT, utiliptables.ChainOutput}, {utiliptables.TableNAT, utiliptables.ChainOutput},
{utiliptables.TableNAT, utiliptables.ChainPrerouting}, {utiliptables.TableNAT, utiliptables.ChainPrerouting},
} }
args := []string{"-m", "comment", "--comment", "kube hostport portals", args := []string{
"-m", "comment", "--comment", "kube hostport portals",
"-m", "addrtype", "--dst-type", "LOCAL", "-m", "addrtype", "--dst-type", "LOCAL",
"-j", string(kubeHostportsChain)} "-j", string(kubeHostportsChain),
}
for _, tc := range tableChainsNeedJumpServices { for _, tc := range tableChainsNeedJumpServices {
// KUBE-HOSTPORTS chain needs to be appended to the system chains. // KUBE-HOSTPORTS chain needs to be appended to the system chains.
// This ensures KUBE-SERVICES chain gets processed first. // This ensures KUBE-SERVICES chain gets processed first.

View File

@ -59,6 +59,7 @@ type hostportManager struct {
mu sync.Mutex mu sync.Mutex
} }
// NewHostportManager creates a new HostPortManager
func NewHostportManager(iptables utiliptables.Interface) HostPortManager { func NewHostportManager(iptables utiliptables.Interface) HostPortManager {
h := &hostportManager{ h := &hostportManager{
hostPortMap: make(map[hostport]closeable), hostPortMap: make(map[hostport]closeable),
@ -78,13 +79,6 @@ func (hm *hostportManager) Add(id string, podPortMapping *PodPortMapping, natInt
return nil return nil
} }
podFullName := getPodFullName(podPortMapping) podFullName := getPodFullName(podPortMapping)
// skip if there is no hostport needed
hostportMappings := gatherHostportMappings(podPortMapping)
if len(hostportMappings) == 0 {
return nil
}
// IP.To16() returns nil if IP is not a valid IPv4 or IPv6 address // IP.To16() returns nil if IP is not a valid IPv4 or IPv6 address
if podPortMapping.IP.To16() == nil { if podPortMapping.IP.To16() == nil {
return fmt.Errorf("invalid or missing IP of pod %s", podFullName) return fmt.Errorf("invalid or missing IP of pod %s", podFullName)
@ -92,11 +86,17 @@ func (hm *hostportManager) Add(id string, podPortMapping *PodPortMapping, natInt
podIP := podPortMapping.IP.String() podIP := podPortMapping.IP.String()
isIPv6 := utilnet.IsIPv6(podPortMapping.IP) isIPv6 := utilnet.IsIPv6(podPortMapping.IP)
// skip if there is no hostport needed
hostportMappings := gatherHostportMappings(podPortMapping, isIPv6)
if len(hostportMappings) == 0 {
return nil
}
if isIPv6 != hm.iptables.IsIPv6() { if isIPv6 != hm.iptables.IsIPv6() {
return fmt.Errorf("HostPortManager IP family mismatch: %v, isIPv6 - %v", podIP, isIPv6) return fmt.Errorf("HostPortManager IP family mismatch: %v, isIPv6 - %v", podIP, isIPv6)
} }
if err = ensureKubeHostportChains(hm.iptables, natInterfaceName); err != nil { if err := ensureKubeHostportChains(hm.iptables, natInterfaceName); err != nil {
return err return err
} }
@ -205,8 +205,8 @@ func (hm *hostportManager) Remove(id string, podPortMapping *PodPortMapping) (er
return nil return nil
} }
hostportMappings := gatherHostportMappings(podPortMapping) hostportMappings := gatherHostportMappings(podPortMapping, hm.iptables.IsIPv6())
if len(hostportMappings) <= 0 { if len(hostportMappings) == 0 {
return nil return nil
} }
@ -238,6 +238,12 @@ func (hm *hostportManager) Remove(id string, podPortMapping *PodPortMapping) (er
} }
} }
// exit if there is nothing to remove
// don´t forget to clean up opened pod host ports
if len(existingChainsToRemove) == 0 {
return hm.closeHostports(hostportMappings)
}
natChains := bytes.NewBuffer(nil) natChains := bytes.NewBuffer(nil)
natRules := bytes.NewBuffer(nil) natRules := bytes.NewBuffer(nil)
writeLine(natChains, "*nat") writeLine(natChains, "*nat")
@ -252,7 +258,7 @@ func (hm *hostportManager) Remove(id string, podPortMapping *PodPortMapping) (er
} }
writeLine(natRules, "COMMIT") writeLine(natRules, "COMMIT")
if err = hm.syncIPTables(append(natChains.Bytes(), natRules.Bytes()...)); err != nil { if err := hm.syncIPTables(append(natChains.Bytes(), natRules.Bytes()...)); err != nil {
return err return err
} }
@ -286,7 +292,12 @@ func (hm *hostportManager) openHostports(podPortMapping *PodPortMapping) (map[ho
continue continue
} }
hp := portMappingToHostport(pm) // HostIP IP family is not handled by this port opener
if pm.HostIP != "" && utilnet.IsIPv6String(pm.HostIP) != hm.iptables.IsIPv6() {
continue
}
hp := portMappingToHostport(pm, hm.getIPFamily())
socket, err := hm.portOpener(&hp) socket, err := hm.portOpener(&hp)
if err != nil { if err != nil {
retErr = fmt.Errorf("cannot open hostport %d for pod %s: %v", pm.HostPort, getPodFullName(podPortMapping), err) retErr = fmt.Errorf("cannot open hostport %d for pod %s: %v", pm.HostPort, getPodFullName(podPortMapping), err)
@ -311,7 +322,7 @@ func (hm *hostportManager) openHostports(podPortMapping *PodPortMapping) (map[ho
func (hm *hostportManager) closeHostports(hostportMappings []*PortMapping) error { func (hm *hostportManager) closeHostports(hostportMappings []*PortMapping) error {
errList := []error{} errList := []error{}
for _, pm := range hostportMappings { for _, pm := range hostportMappings {
hp := portMappingToHostport(pm) hp := portMappingToHostport(pm, hm.getIPFamily())
if socket, ok := hm.hostPortMap[hp]; ok { if socket, ok := hm.hostPortMap[hp]; ok {
klog.V(2).Infof("Closing host port %s", hp.String()) klog.V(2).Infof("Closing host port %s", hp.String())
if err := socket.Close(); err != nil { if err := socket.Close(); err != nil {
@ -326,6 +337,15 @@ func (hm *hostportManager) closeHostports(hostportMappings []*PortMapping) error
return utilerrors.NewAggregate(errList) return utilerrors.NewAggregate(errList)
} }
// getIPFamily returns the hostPortManager IP family
func (hm *hostportManager) getIPFamily() ipFamily {
family := IPv4
if hm.iptables.IsIPv6() {
family = IPv6
}
return family
}
// getHostportChain takes id, hostport and protocol for a pod and returns associated iptables chain. // getHostportChain takes id, hostport and protocol for a pod and returns associated iptables chain.
// This is computed by hashing (sha256) then encoding to base32 and truncating with the prefix // This is computed by hashing (sha256) then encoding to base32 and truncating with the prefix
// "KUBE-HP-". We do this because IPTables Chain Names must be <= 28 chars long, and the longer // "KUBE-HP-". We do this because IPTables Chain Names must be <= 28 chars long, and the longer
@ -339,12 +359,16 @@ func getHostportChain(id string, pm *PortMapping) utiliptables.Chain {
} }
// gatherHostportMappings returns all the PortMappings which has hostport for a pod // gatherHostportMappings returns all the PortMappings which has hostport for a pod
func gatherHostportMappings(podPortMapping *PodPortMapping) []*PortMapping { // it filters the PortMappings that use HostIP and doesn't match the IP family specified
func gatherHostportMappings(podPortMapping *PodPortMapping, isIPv6 bool) []*PortMapping {
mappings := []*PortMapping{} mappings := []*PortMapping{}
for _, pm := range podPortMapping.PortMappings { for _, pm := range podPortMapping.PortMappings {
if pm.HostPort <= 0 { if pm.HostPort <= 0 {
continue continue
} }
if pm.HostIP != "" && utilnet.IsIPv6String(pm.HostIP) != isIPv6 {
continue
}
mappings = append(mappings, pm) mappings = append(mappings, pm)
} }
return mappings return mappings

View File

@ -128,6 +128,7 @@ func TestOpenCloseHostports(t *testing.T) {
} }
iptables := NewFakeIPTables() iptables := NewFakeIPTables()
iptables.protocol = utiliptables.ProtocolIPv4
portOpener := NewFakeSocketManager() portOpener := NewFakeSocketManager()
manager := &hostportManager{ manager := &hostportManager{
hostPortMap: make(map[hostport]closeable), hostPortMap: make(map[hostport]closeable),
@ -151,7 +152,7 @@ func TestOpenCloseHostports(t *testing.T) {
countSctp := 0 countSctp := 0
for _, pm := range tc.podPortMapping.PortMappings { for _, pm := range tc.podPortMapping.PortMappings {
if pm.Protocol == v1.ProtocolSCTP { if pm.Protocol == v1.ProtocolSCTP {
countSctp += 1 countSctp++
} }
} }
assert.EqualValues(t, len(mapping), len(tc.podPortMapping.PortMappings)-countSctp) assert.EqualValues(t, len(mapping), len(tc.podPortMapping.PortMappings)-countSctp)
@ -211,7 +212,8 @@ func TestOpenCloseHostports(t *testing.T) {
{ {
portMappings: []*PortMapping{ portMappings: []*PortMapping{
{HostPort: 9999, Protocol: v1.ProtocolTCP}, {HostPort: 9999, Protocol: v1.ProtocolTCP},
{HostPort: 9999, Protocol: v1.ProtocolUDP}}, {HostPort: 9999, Protocol: v1.ProtocolUDP},
},
}, },
} }
@ -230,6 +232,7 @@ func TestOpenCloseHostports(t *testing.T) {
func TestHostportManager(t *testing.T) { func TestHostportManager(t *testing.T) {
iptables := NewFakeIPTables() iptables := NewFakeIPTables()
iptables.protocol = utiliptables.ProtocolIPv4
portOpener := NewFakeSocketManager() portOpener := NewFakeSocketManager()
manager := &hostportManager{ manager := &hostportManager{
hostPortMap: make(map[hostport]closeable), hostPortMap: make(map[hostport]closeable),
@ -237,7 +240,6 @@ func TestHostportManager(t *testing.T) {
portOpener: portOpener.openFakeSocket, portOpener: portOpener.openFakeSocket,
execer: exec.New(), execer: exec.New(),
} }
testCases := []struct { testCases := []struct {
mapping *PodPortMapping mapping *PodPortMapping
expectError bool expectError bool
@ -318,7 +320,7 @@ func TestHostportManager(t *testing.T) {
mapping: &PodPortMapping{ mapping: &PodPortMapping{
Name: "pod3", Name: "pod3",
Namespace: "ns1", Namespace: "ns1",
IP: net.ParseIP("2001:beef::2"), IP: net.ParseIP("192.168.12.12"),
HostNetwork: false, HostNetwork: false,
PortMappings: []*PortMapping{ PortMappings: []*PortMapping{
{ {
@ -330,7 +332,7 @@ func TestHostportManager(t *testing.T) {
}, },
expectError: true, expectError: true,
}, },
// fail HostPort with PodIP and HostIP using different families // skip HostPort with PodIP and HostIP using different families
{ {
mapping: &PodPortMapping{ mapping: &PodPortMapping{
Name: "pod4", Name: "pod4",
@ -346,7 +348,7 @@ func TestHostportManager(t *testing.T) {
}, },
}, },
}, },
expectError: true, expectError: false,
}, },
// open same HostPort on different IP // open same HostPort on different IP
@ -408,9 +410,15 @@ func TestHostportManager(t *testing.T) {
} }
// Check port opened // Check port opened
expectedPorts := []hostport{{"", 8080, "tcp"}, expectedPorts := []hostport{
{"", 8081, "udp"}, {"", 8443, "tcp"}, {"127.0.0.1", 8888, "tcp"}, {IPv4, "", 8080, "tcp"},
{"127.0.0.2", 8888, "tcp"}, {"", 9999, "tcp"}, {"", 9999, "udp"}} {IPv4, "", 8081, "udp"},
{IPv4, "", 8443, "tcp"},
{IPv4, "127.0.0.1", 8888, "tcp"},
{IPv4, "127.0.0.2", 8888, "tcp"},
{IPv4, "", 9999, "tcp"},
{IPv4, "", 9999, "udp"},
}
openedPorts := make(map[hostport]bool) openedPorts := make(map[hostport]bool)
for hp, port := range portOpener.mem { for hp, port := range portOpener.mem {
if !port.closed { if !port.closed {
@ -499,8 +507,10 @@ func TestHostportManager(t *testing.T) {
remainingChains[strings.TrimSpace(line)] = true remainingChains[strings.TrimSpace(line)] = true
} }
} }
expectDeletedChains := []string{"KUBE-HP-4YVONL46AKYWSKS3", "KUBE-HP-7THKRFSEH4GIIXK7", "KUBE-HP-5N7UH5JAXCVP5UJR", expectDeletedChains := []string{
"KUBE-HP-TUKTZ736U5JD5UTK", "KUBE-HP-CAAJ45HDITK7ARGM", "KUBE-HP-WFUNFVXVDLD5ZVXN", "KUBE-HP-4MFWH2F2NAOMYD6A"} "KUBE-HP-4YVONL46AKYWSKS3", "KUBE-HP-7THKRFSEH4GIIXK7", "KUBE-HP-5N7UH5JAXCVP5UJR",
"KUBE-HP-TUKTZ736U5JD5UTK", "KUBE-HP-CAAJ45HDITK7ARGM", "KUBE-HP-WFUNFVXVDLD5ZVXN", "KUBE-HP-4MFWH2F2NAOMYD6A",
}
for _, chain := range expectDeletedChains { for _, chain := range expectDeletedChains {
_, ok := remainingChains[chain] _, ok := remainingChains[chain]
assert.EqualValues(t, false, ok) assert.EqualValues(t, false, ok)
@ -537,7 +547,6 @@ func TestHostportManagerIPv6(t *testing.T) {
portOpener: portOpener.openFakeSocket, portOpener: portOpener.openFakeSocket,
execer: exec.New(), execer: exec.New(),
} }
testCases := []struct { testCases := []struct {
mapping *PodPortMapping mapping *PodPortMapping
expectError bool expectError bool
@ -639,7 +648,7 @@ func TestHostportManagerIPv6(t *testing.T) {
} }
// Check port opened // Check port opened
expectedPorts := []hostport{{"", 8080, "tcp"}, {"", 8081, "udp"}, {"", 8443, "tcp"}} expectedPorts := []hostport{{IPv6, "", 8080, "tcp"}, {IPv6, "", 8081, "udp"}, {IPv6, "", 8443, "tcp"}}
openedPorts := make(map[hostport]bool) openedPorts := make(map[hostport]bool)
for hp, port := range portOpener.mem { for hp, port := range portOpener.mem {
if !port.closed { if !port.closed {
@ -657,7 +666,7 @@ func TestHostportManagerIPv6(t *testing.T) {
err := iptables.SaveInto(utiliptables.TableNAT, raw) err := iptables.SaveInto(utiliptables.TableNAT, raw)
assert.NoError(t, err) assert.NoError(t, err)
lines := strings.Split(string(raw.Bytes()), "\n") lines := strings.Split(raw.String(), "\n")
expectedLines := map[string]bool{ expectedLines := map[string]bool{
`*nat`: true, `*nat`: true,
`:KUBE-HOSTPORTS - [0:0]`: true, `:KUBE-HOSTPORTS - [0:0]`: true,
@ -704,7 +713,7 @@ func TestHostportManagerIPv6(t *testing.T) {
raw.Reset() raw.Reset()
err = iptables.SaveInto(utiliptables.TableNAT, raw) err = iptables.SaveInto(utiliptables.TableNAT, raw)
assert.NoError(t, err) assert.NoError(t, err)
lines = strings.Split(string(raw.Bytes()), "\n") lines = strings.Split(raw.String(), "\n")
remainingChains := make(map[string]bool) remainingChains := make(map[string]bool)
for _, line := range lines { for _, line := range lines {
if strings.HasPrefix(line, ":") { if strings.HasPrefix(line, ":") {

View File

@ -27,15 +27,15 @@ import (
) )
type fakeSocket struct { type fakeSocket struct {
ip string closed bool
port int32 port int32
protocol string protocol string
closed bool ip string
} }
func (f *fakeSocket) Close() error { func (f *fakeSocket) Close() error {
if f.closed { if f.closed {
return fmt.Errorf("Socket %q.%s already closed!", f.port, f.protocol) return fmt.Errorf("socket %q.%s already closed", f.port, f.protocol)
} }
f.closed = true f.closed = true
return nil return nil
@ -53,7 +53,12 @@ func (f *fakeSocketManager) openFakeSocket(hp *hostport) (closeable, error) {
if socket, ok := f.mem[*hp]; ok && !socket.closed { if socket, ok := f.mem[*hp]; ok && !socket.closed {
return nil, fmt.Errorf("hostport is occupied") return nil, fmt.Errorf("hostport is occupied")
} }
fs := &fakeSocket{hp.ip, hp.port, hp.protocol, false} fs := &fakeSocket{
port: hp.port,
protocol: hp.protocol,
closed: false,
ip: hp.ip,
}
f.mem[*hp] = fs f.mem[*hp] = fs
return fs, nil return fs, nil
} }
@ -81,5 +86,4 @@ func TestEnsureKubeHostportChains(t *testing.T) {
assert.EqualValues(t, len(chain.rules), 1) assert.EqualValues(t, len(chain.rules), 1)
assert.Contains(t, chain.rules[0], jumpRule) assert.Contains(t, chain.rules[0], jumpRule)
} }
} }