Specify a port range to ILB firewall rule create.

This commit is contained in:
Pavithra Ramesh 2019-11-13 17:15:59 -08:00
parent 8af6906d1f
commit 44f0b26ab9
2 changed files with 125 additions and 10 deletions

View File

@ -22,6 +22,7 @@ import (
"context"
"encoding/json"
"fmt"
"sort"
"strconv"
"strings"
@ -48,7 +49,7 @@ func (g *Cloud) ensureInternalLoadBalancer(clusterName, clusterID string, svc *v
}
nm := types.NamespacedName{Name: svc.Name, Namespace: svc.Namespace}
ports, protocol := getPortsAndProtocol(svc.Spec.Ports)
ports, _, protocol := getPortsAndProtocol(svc.Spec.Ports)
if protocol != v1.ProtocolTCP && protocol != v1.ProtocolUDP {
return nil, fmt.Errorf("Invalid protocol %s, only TCP and UDP are supported", string(protocol))
}
@ -231,7 +232,7 @@ func (g *Cloud) updateInternalLoadBalancer(clusterName, clusterID string, svc *v
}
// Generate the backend service name
_, protocol := getPortsAndProtocol(svc.Spec.Ports)
_, _, protocol := getPortsAndProtocol(svc.Spec.Ports)
scheme := cloud.SchemeInternal
loadBalancerName := g.GetLoadBalancerName(context.TODO(), clusterName, svc)
backendServiceName := makeBackendServiceName(loadBalancerName, clusterID, shareBackendService(svc), scheme, protocol, svc.Spec.SessionAffinity)
@ -241,7 +242,7 @@ func (g *Cloud) updateInternalLoadBalancer(clusterName, clusterID string, svc *v
func (g *Cloud) ensureInternalLoadBalancerDeleted(clusterName, clusterID string, svc *v1.Service) error {
loadBalancerName := g.GetLoadBalancerName(context.TODO(), clusterName, svc)
_, protocol := getPortsAndProtocol(svc.Spec.Ports)
_, _, protocol := getPortsAndProtocol(svc.Spec.Ports)
scheme := cloud.SchemeInternal
sharedBackend := shareBackendService(svc)
sharedHealthCheck := !servicehelpers.RequestsOnlyLocalTraffic(svc)
@ -344,7 +345,7 @@ func (g *Cloud) teardownInternalHealthCheckAndFirewall(svc *v1.Service, hcName s
return nil
}
func (g *Cloud) ensureInternalFirewall(svc *v1.Service, fwName, fwDesc string, sourceRanges []string, ports []string, protocol v1.Protocol, nodes []*v1.Node, legacyFwName string) error {
func (g *Cloud) ensureInternalFirewall(svc *v1.Service, fwName, fwDesc string, sourceRanges []string, portRanges []string, protocol v1.Protocol, nodes []*v1.Node, legacyFwName string) error {
klog.V(2).Infof("ensureInternalFirewall(%v): checking existing firewall", fwName)
targetTags, err := g.GetNodeTags(nodeNames(nodes))
if err != nil {
@ -388,7 +389,7 @@ func (g *Cloud) ensureInternalFirewall(svc *v1.Service, fwName, fwDesc string, s
Allowed: []*compute.FirewallAllowed{
{
IPProtocol: strings.ToLower(string(protocol)),
Ports: ports,
Ports: portRanges,
},
},
}
@ -421,12 +422,12 @@ func (g *Cloud) ensureInternalFirewall(svc *v1.Service, fwName, fwDesc string, s
func (g *Cloud) ensureInternalFirewalls(loadBalancerName, ipAddress, clusterID string, nm types.NamespacedName, svc *v1.Service, healthCheckPort string, sharedHealthCheck bool, nodes []*v1.Node) error {
// First firewall is for ingress traffic
fwDesc := makeFirewallDescription(nm.String(), ipAddress)
ports, protocol := getPortsAndProtocol(svc.Spec.Ports)
_, portRanges, protocol := getPortsAndProtocol(svc.Spec.Ports)
sourceRanges, err := servicehelpers.GetLoadBalancerSourceRanges(svc)
if err != nil {
return err
}
err = g.ensureInternalFirewall(svc, MakeFirewallName(loadBalancerName), fwDesc, sourceRanges.StringSlice(), ports, protocol, nodes, loadBalancerName)
err = g.ensureInternalFirewall(svc, MakeFirewallName(loadBalancerName), fwDesc, sourceRanges.StringSlice(), portRanges, protocol, nodes, loadBalancerName)
if err != nil {
return err
}
@ -747,17 +748,62 @@ func backendSvcEqual(a, b *compute.BackendService) bool {
backendsListEqual(a.Backends, b.Backends)
}
func getPortsAndProtocol(svcPorts []v1.ServicePort) (ports []string, protocol v1.Protocol) {
func getPortsAndProtocol(svcPorts []v1.ServicePort) (ports []string, portRanges []string, protocol v1.Protocol) {
if len(svcPorts) == 0 {
return []string{}, v1.ProtocolUDP
return []string{}, []string{}, v1.ProtocolUDP
}
// GCP doesn't support multiple protocols for a single load balancer
protocol = svcPorts[0].Protocol
portInts := []int{}
for _, p := range svcPorts {
ports = append(ports, strconv.Itoa(int(p.Port)))
portInts = append(portInts, int(p.Port))
}
return ports, protocol
return ports, getPortRanges(portInts), protocol
}
func getPortRanges(ports []int) (ranges []string) {
if len(ports) < 1 {
return ranges
}
sort.Ints(ports)
start := ports[0]
prev := ports[0]
for ix, current := range ports {
switch {
case current == prev:
// Loop over duplicates, except if the end of list is reached.
if ix == len(ports)-1 {
if start == current {
ranges = append(ranges, fmt.Sprintf("%d", current))
} else {
ranges = append(ranges, fmt.Sprintf("%d-%d", start, current))
}
}
case current == prev+1:
// continue the streak, create the range if this is the last element in the list.
if ix == len(ports)-1 {
ranges = append(ranges, fmt.Sprintf("%d-%d", start, current))
}
default:
// current is not prev + 1, streak is broken. Construct the range and handle last element case.
if start == prev {
ranges = append(ranges, fmt.Sprintf("%d", prev))
} else {
ranges = append(ranges, fmt.Sprintf("%d-%d", start, prev))
}
if ix == len(ports)-1 {
ranges = append(ranges, fmt.Sprintf("%d", current))
}
// reset start element
start = current
}
prev = current
}
return ranges
}
func (g *Cloud) getBackendServiceLink(name string) string {

View File

@ -21,6 +21,7 @@ package gce
import (
"context"
"fmt"
"reflect"
"strings"
"testing"
@ -1395,3 +1396,71 @@ func TestEnsureInternalLoadBalancerCustomSubnet(t *testing.T) {
}
assertInternalLbResourcesDeleted(t, gce, svc, vals, true)
}
func TestGetPortRanges(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
Desc string
Input []int
Result []string
}{
{Desc: "All Unique", Input: []int{8, 66, 23, 13, 89}, Result: []string{"8", "13", "23", "66", "89"}},
{Desc: "All Unique Sorted", Input: []int{1, 7, 9, 16, 26}, Result: []string{"1", "7", "9", "16", "26"}},
{Desc: "Ranges", Input: []int{56, 78, 67, 79, 21, 80, 12}, Result: []string{"12", "21", "56", "67", "78-80"}},
{Desc: "Ranges Sorted", Input: []int{5, 7, 90, 1002, 1003, 1004, 1005, 2501}, Result: []string{"5", "7", "90", "1002-1005", "2501"}},
{Desc: "Ranges Duplicates", Input: []int{15, 37, 900, 2002, 2003, 2003, 2004, 2004}, Result: []string{"15", "37", "900", "2002-2004"}},
{Desc: "Duplicates", Input: []int{10, 10, 10, 10, 10}, Result: []string{"10"}},
{Desc: "Only ranges", Input: []int{18, 19, 20, 21, 22, 55, 56, 77, 78, 79, 3504, 3505, 3506}, Result: []string{"18-22", "55-56", "77-79", "3504-3506"}},
{Desc: "Single Range", Input: []int{6000, 6001, 6002, 6003, 6004, 6005}, Result: []string{"6000-6005"}},
{Desc: "One value", Input: []int{12}, Result: []string{"12"}},
{Desc: "Empty", Input: []int{}, Result: nil},
} {
result := getPortRanges(tc.Input)
if !reflect.DeepEqual(result, tc.Result) {
t.Errorf("Expected %v, got %v for test case %s", tc.Result, result, tc.Desc)
}
}
}
func TestEnsureInternalFirewallPortRanges(t *testing.T) {
gce, err := fakeGCECloud(DefaultTestClusterValues())
require.NoError(t, err)
vals := DefaultTestClusterValues()
svc := fakeLoadbalancerService(string(LBTypeInternal))
lbName := gce.GetLoadBalancerName(context.TODO(), "", svc)
fwName := MakeFirewallName(lbName)
tc := struct {
Input []int
Result []string
}{
Input: []int{15, 37, 900, 2002, 2003, 2003, 2004, 2004}, Result: []string{"15", "37", "900", "2002-2004"},
}
c := gce.c.(*cloud.MockGCE)
c.MockFirewalls.InsertHook = nil
c.MockFirewalls.UpdateHook = nil
nodes, err := createAndInsertNodes(gce, []string{"test-node-1"}, vals.ZoneName)
require.NoError(t, err)
sourceRange := []string{"10.0.0.0/20"}
// Manually create a firewall rule with the legacy name - lbName
gce.ensureInternalFirewall(
svc,
fwName,
"firewall with legacy name",
sourceRange,
getPortRanges(tc.Input),
v1.ProtocolTCP,
nodes,
"")
if err != nil {
t.Errorf("Unexpected error %v when ensuring legacy firewall %s for svc %+v", err, lbName, svc)
}
existingFirewall, err := gce.GetFirewall(fwName)
if err != nil || existingFirewall == nil || len(existingFirewall.Allowed) == 0 {
t.Errorf("Unexpected error %v when looking up firewall %s, Got firewall %+v", err, fwName, existingFirewall)
}
existingPorts := existingFirewall.Allowed[0].Ports
if !reflect.DeepEqual(existingPorts, tc.Result) {
t.Errorf("Expected firewall rule with ports %v,got %v", tc.Result, existingPorts)
}
}