diff --git a/pkg/kubelet/network/hostport/hostport.go b/pkg/kubelet/network/hostport/hostport.go index cdfeb113553..800589e9eb6 100644 --- a/pkg/kubelet/network/hostport/hostport.go +++ b/pkg/kubelet/network/hostport/hostport.go @@ -126,38 +126,6 @@ func openLocalPort(hp *hostport) (closeable, error) { return socket, nil } -// openHostports opens all given hostports using the given hostportOpener -// If encounter any error, clean up and return the error -// If all ports are opened successfully, return the hostport and socket mapping -// TODO: move openHostports and closeHostports into a common struct -func openHostports(portOpener hostportOpener, podPortMapping *PodPortMapping) (map[hostport]closeable, error) { - var retErr error - ports := make(map[hostport]closeable) - for _, pm := range podPortMapping.PortMappings { - if pm.HostPort <= 0 { - continue - } - hp := portMappingToHostport(pm) - socket, err := portOpener(&hp) - if err != nil { - retErr = fmt.Errorf("cannot open hostport %d for pod %s: %v", pm.HostPort, getPodFullName(podPortMapping), err) - break - } - ports[hp] = socket - } - - // If encounter any error, close all hostports that just got opened. - if retErr != nil { - for hp, socket := range ports { - if err := socket.Close(); err != nil { - glog.Errorf("Cannot clean up hostport %d for pod %s: %v", hp.port, getPodFullName(podPortMapping), err) - } - } - return nil, retErr - } - return ports, nil -} - // portMappingToHostport creates hostport structure based on input portmapping func portMappingToHostport(portMapping *PortMapping) hostport { return hostport{ diff --git a/pkg/kubelet/network/hostport/hostport_manager.go b/pkg/kubelet/network/hostport/hostport_manager.go index 3177ac5adff..6d85c75efed 100644 --- a/pkg/kubelet/network/hostport/hostport_manager.go +++ b/pkg/kubelet/network/hostport/hostport_manager.go @@ -96,7 +96,7 @@ func (hm *hostportManager) Add(id string, podPortMapping *PodPortMapping, natInt defer hm.mu.Unlock() // try to open hostports - ports, err := openHostports(hm.portOpener, podPortMapping) + ports, err := hm.openHostports(podPortMapping) if err != nil { return err } @@ -254,8 +254,38 @@ func (hm *hostportManager) syncIPTables(lines []byte) error { return nil } +// openHostports opens all given hostports using the given hostportOpener +// If encounter any error, clean up and return the error +// If all ports are opened successfully, return the hostport and socket mapping +func (hm *hostportManager) openHostports(podPortMapping *PodPortMapping) (map[hostport]closeable, error) { + var retErr error + ports := make(map[hostport]closeable) + for _, pm := range podPortMapping.PortMappings { + if pm.HostPort <= 0 { + continue + } + hp := portMappingToHostport(pm) + socket, err := hm.portOpener(&hp) + if err != nil { + retErr = fmt.Errorf("cannot open hostport %d for pod %s: %v", pm.HostPort, getPodFullName(podPortMapping), err) + break + } + ports[hp] = socket + } + + // If encounter any error, close all hostports that just got opened. + if retErr != nil { + for hp, socket := range ports { + if err := socket.Close(); err != nil { + glog.Errorf("Cannot clean up hostport %d for pod %s: %v", hp.port, getPodFullName(podPortMapping), err) + } + } + return nil, retErr + } + return ports, nil +} + // closeHostports tries to close all the listed host ports -// TODO: move closeHostports and openHostports into a common struct func (hm *hostportManager) closeHostports(hostportMappings []*PortMapping) error { errList := []error{} for _, pm := range hostportMappings { diff --git a/pkg/kubelet/network/hostport/hostport_manager_test.go b/pkg/kubelet/network/hostport/hostport_manager_test.go index 019d36e78d2..1b3b460cb5a 100644 --- a/pkg/kubelet/network/hostport/hostport_manager_test.go +++ b/pkg/kubelet/network/hostport/hostport_manager_test.go @@ -28,6 +28,134 @@ import ( "k8s.io/utils/exec" ) +func TestOpenCloseHostports(t *testing.T) { + openPortCases := []struct { + podPortMapping *PodPortMapping + expectError bool + }{ + { + &PodPortMapping{ + Namespace: "ns1", + Name: "n0", + }, + false, + }, + { + &PodPortMapping{ + Namespace: "ns1", + Name: "n1", + PortMappings: []*PortMapping{ + {HostPort: 80, Protocol: v1.Protocol("TCP")}, + {HostPort: 8080, Protocol: v1.Protocol("TCP")}, + {HostPort: 443, Protocol: v1.Protocol("TCP")}, + }, + }, + false, + }, + { + &PodPortMapping{ + Namespace: "ns1", + Name: "n2", + PortMappings: []*PortMapping{ + {HostPort: 80, Protocol: v1.Protocol("TCP")}, + }, + }, + true, + }, + { + &PodPortMapping{ + Namespace: "ns1", + Name: "n3", + PortMappings: []*PortMapping{ + {HostPort: 8081, Protocol: v1.Protocol("TCP")}, + {HostPort: 8080, Protocol: v1.Protocol("TCP")}, + }, + }, + true, + }, + { + &PodPortMapping{ + Namespace: "ns1", + Name: "n3", + PortMappings: []*PortMapping{ + {HostPort: 8081, Protocol: v1.Protocol("TCP")}, + }, + }, + false, + }, + } + + iptables := NewFakeIPTables() + portOpener := NewFakeSocketManager() + manager := &hostportManager{ + hostPortMap: make(map[hostport]closeable), + iptables: iptables, + portOpener: portOpener.openFakeSocket, + execer: exec.New(), + } + + for _, tc := range openPortCases { + mapping, err := manager.openHostports(tc.podPortMapping) + if tc.expectError { + assert.Error(t, err) + continue + } + assert.NoError(t, err) + assert.EqualValues(t, len(mapping), len(tc.podPortMapping.PortMappings)) + } + + // We have 4 ports: 80, 443, 8080, 8081 open now. + closePortCases := []struct { + portMappings []*PortMapping + expectError bool + }{ + { + portMappings: nil, + }, + { + + portMappings: []*PortMapping{ + {HostPort: 80, Protocol: v1.Protocol("TCP")}, + {HostPort: 8080, Protocol: v1.Protocol("TCP")}, + {HostPort: 443, Protocol: v1.Protocol("TCP")}, + }, + }, + { + + portMappings: []*PortMapping{ + {HostPort: 80, Protocol: v1.Protocol("TCP")}, + }, + }, + { + portMappings: []*PortMapping{ + {HostPort: 8081, Protocol: v1.Protocol("TCP")}, + {HostPort: 8080, Protocol: v1.Protocol("TCP")}, + }, + }, + { + portMappings: []*PortMapping{ + {HostPort: 8081, Protocol: v1.Protocol("TCP")}, + }, + }, + { + portMappings: []*PortMapping{ + {HostPort: 7070, Protocol: v1.Protocol("TCP")}, + }, + }, + } + + for _, tc := range closePortCases { + err := manager.closeHostports(tc.portMappings) + if tc.expectError { + assert.Error(t, err) + continue + } + assert.NoError(t, err) + } + // Clear all elements in hostPortMap + assert.Zero(t, len(manager.hostPortMap)) +} + func TestHostportManager(t *testing.T) { iptables := NewFakeIPTables() portOpener := NewFakeSocketManager() diff --git a/pkg/kubelet/network/hostport/hostport_test.go b/pkg/kubelet/network/hostport/hostport_test.go index 306cb6252d7..97dd907c308 100644 --- a/pkg/kubelet/network/hostport/hostport_test.go +++ b/pkg/kubelet/network/hostport/hostport_test.go @@ -21,7 +21,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "k8s.io/api/core/v1" utiliptables "k8s.io/kubernetes/pkg/util/iptables" ) @@ -56,75 +55,6 @@ func (f *fakeSocketManager) openFakeSocket(hp *hostport) (closeable, error) { return fs, nil } -func TestOpenHostports(t *testing.T) { - opener := NewFakeSocketManager() - testCases := []struct { - podPortMapping *PodPortMapping - expectError bool - }{ - { - &PodPortMapping{ - Namespace: "ns1", - Name: "n0", - }, - false, - }, - { - &PodPortMapping{ - Namespace: "ns1", - Name: "n1", - PortMappings: []*PortMapping{ - {HostPort: 80, Protocol: v1.Protocol("TCP")}, - {HostPort: 8080, Protocol: v1.Protocol("TCP")}, - {HostPort: 443, Protocol: v1.Protocol("TCP")}, - }, - }, - false, - }, - { - &PodPortMapping{ - Namespace: "ns1", - Name: "n2", - PortMappings: []*PortMapping{ - {HostPort: 80, Protocol: v1.Protocol("TCP")}, - }, - }, - true, - }, - { - &PodPortMapping{ - Namespace: "ns1", - Name: "n3", - PortMappings: []*PortMapping{ - {HostPort: 8081, Protocol: v1.Protocol("TCP")}, - {HostPort: 8080, Protocol: v1.Protocol("TCP")}, - }, - }, - true, - }, - { - &PodPortMapping{ - Namespace: "ns1", - Name: "n3", - PortMappings: []*PortMapping{ - {HostPort: 8081, Protocol: v1.Protocol("TCP")}, - }, - }, - false, - }, - } - - for _, tc := range testCases { - mapping, err := openHostports(opener.openFakeSocket, tc.podPortMapping) - if tc.expectError { - assert.Error(t, err) - continue - } - assert.NoError(t, err) - assert.EqualValues(t, len(mapping), len(tc.podPortMapping.PortMappings)) - } -} - func TestEnsureKubeHostportChains(t *testing.T) { interfaceName := "cbr0" builtinChains := []string{"PREROUTING", "OUTPUT"}