mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-23 11:50:44 +00:00
Specify a port range when creating the ILB firewall rule.
This commit is contained in:
parent
8af6906d1f
commit
44f0b26ab9
@ -22,6 +22,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -48,7 +49,7 @@ func (g *Cloud) ensureInternalLoadBalancer(clusterName, clusterID string, svc *v
|
||||
}
|
||||
|
||||
nm := types.NamespacedName{Name: svc.Name, Namespace: svc.Namespace}
|
||||
ports, protocol := getPortsAndProtocol(svc.Spec.Ports)
|
||||
ports, _, protocol := getPortsAndProtocol(svc.Spec.Ports)
|
||||
if protocol != v1.ProtocolTCP && protocol != v1.ProtocolUDP {
|
||||
return nil, fmt.Errorf("Invalid protocol %s, only TCP and UDP are supported", string(protocol))
|
||||
}
|
||||
@ -231,7 +232,7 @@ func (g *Cloud) updateInternalLoadBalancer(clusterName, clusterID string, svc *v
|
||||
}
|
||||
|
||||
// Generate the backend service name
|
||||
_, protocol := getPortsAndProtocol(svc.Spec.Ports)
|
||||
_, _, protocol := getPortsAndProtocol(svc.Spec.Ports)
|
||||
scheme := cloud.SchemeInternal
|
||||
loadBalancerName := g.GetLoadBalancerName(context.TODO(), clusterName, svc)
|
||||
backendServiceName := makeBackendServiceName(loadBalancerName, clusterID, shareBackendService(svc), scheme, protocol, svc.Spec.SessionAffinity)
|
||||
@ -241,7 +242,7 @@ func (g *Cloud) updateInternalLoadBalancer(clusterName, clusterID string, svc *v
|
||||
|
||||
func (g *Cloud) ensureInternalLoadBalancerDeleted(clusterName, clusterID string, svc *v1.Service) error {
|
||||
loadBalancerName := g.GetLoadBalancerName(context.TODO(), clusterName, svc)
|
||||
_, protocol := getPortsAndProtocol(svc.Spec.Ports)
|
||||
_, _, protocol := getPortsAndProtocol(svc.Spec.Ports)
|
||||
scheme := cloud.SchemeInternal
|
||||
sharedBackend := shareBackendService(svc)
|
||||
sharedHealthCheck := !servicehelpers.RequestsOnlyLocalTraffic(svc)
|
||||
@ -344,7 +345,7 @@ func (g *Cloud) teardownInternalHealthCheckAndFirewall(svc *v1.Service, hcName s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Cloud) ensureInternalFirewall(svc *v1.Service, fwName, fwDesc string, sourceRanges []string, ports []string, protocol v1.Protocol, nodes []*v1.Node, legacyFwName string) error {
|
||||
func (g *Cloud) ensureInternalFirewall(svc *v1.Service, fwName, fwDesc string, sourceRanges []string, portRanges []string, protocol v1.Protocol, nodes []*v1.Node, legacyFwName string) error {
|
||||
klog.V(2).Infof("ensureInternalFirewall(%v): checking existing firewall", fwName)
|
||||
targetTags, err := g.GetNodeTags(nodeNames(nodes))
|
||||
if err != nil {
|
||||
@ -388,7 +389,7 @@ func (g *Cloud) ensureInternalFirewall(svc *v1.Service, fwName, fwDesc string, s
|
||||
Allowed: []*compute.FirewallAllowed{
|
||||
{
|
||||
IPProtocol: strings.ToLower(string(protocol)),
|
||||
Ports: ports,
|
||||
Ports: portRanges,
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -421,12 +422,12 @@ func (g *Cloud) ensureInternalFirewall(svc *v1.Service, fwName, fwDesc string, s
|
||||
func (g *Cloud) ensureInternalFirewalls(loadBalancerName, ipAddress, clusterID string, nm types.NamespacedName, svc *v1.Service, healthCheckPort string, sharedHealthCheck bool, nodes []*v1.Node) error {
|
||||
// First firewall is for ingress traffic
|
||||
fwDesc := makeFirewallDescription(nm.String(), ipAddress)
|
||||
ports, protocol := getPortsAndProtocol(svc.Spec.Ports)
|
||||
_, portRanges, protocol := getPortsAndProtocol(svc.Spec.Ports)
|
||||
sourceRanges, err := servicehelpers.GetLoadBalancerSourceRanges(svc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = g.ensureInternalFirewall(svc, MakeFirewallName(loadBalancerName), fwDesc, sourceRanges.StringSlice(), ports, protocol, nodes, loadBalancerName)
|
||||
err = g.ensureInternalFirewall(svc, MakeFirewallName(loadBalancerName), fwDesc, sourceRanges.StringSlice(), portRanges, protocol, nodes, loadBalancerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -747,17 +748,62 @@ func backendSvcEqual(a, b *compute.BackendService) bool {
|
||||
backendsListEqual(a.Backends, b.Backends)
|
||||
}
|
||||
|
||||
func getPortsAndProtocol(svcPorts []v1.ServicePort) (ports []string, protocol v1.Protocol) {
|
||||
func getPortsAndProtocol(svcPorts []v1.ServicePort) (ports []string, portRanges []string, protocol v1.Protocol) {
|
||||
if len(svcPorts) == 0 {
|
||||
return []string{}, v1.ProtocolUDP
|
||||
return []string{}, []string{}, v1.ProtocolUDP
|
||||
}
|
||||
|
||||
// GCP doesn't support multiple protocols for a single load balancer
|
||||
protocol = svcPorts[0].Protocol
|
||||
portInts := []int{}
|
||||
for _, p := range svcPorts {
|
||||
ports = append(ports, strconv.Itoa(int(p.Port)))
|
||||
portInts = append(portInts, int(p.Port))
|
||||
}
|
||||
return ports, protocol
|
||||
|
||||
return ports, getPortRanges(portInts), protocol
|
||||
}
|
||||
|
||||
// getPortRanges collapses a list of port numbers into a sorted list of
// firewall-rule port specifications, merging consecutive runs into ranges:
// for example, [80, 82, 81, 443] becomes ["80-82", "443"]. Duplicate ports
// are ignored. An empty input returns nil.
//
// Unlike a naive implementation, the caller's slice is NOT modified: the
// ports are copied before sorting.
func getPortRanges(ports []int) (ranges []string) {
	if len(ports) < 1 {
		return ranges
	}
	// Sort a copy so the caller's slice is left untouched.
	sorted := append([]int(nil), ports...)
	sort.Ints(sorted)

	// emit appends the inclusive range [start, end] as either "N" (single
	// port) or "N-M" (range), matching the GCE firewall ports syntax.
	emit := func(start, end int) {
		if start == end {
			ranges = append(ranges, strconv.Itoa(start))
		} else {
			ranges = append(ranges, fmt.Sprintf("%d-%d", start, end))
		}
	}

	start := sorted[0]
	prev := sorted[0]
	for _, current := range sorted[1:] {
		switch {
		case current == prev:
			// Duplicate port: nothing to do, the streak is unchanged.
		case current == prev+1:
			// Consecutive port: the streak continues.
		default:
			// Gap: flush the completed streak and start a new one.
			emit(start, prev)
			start = current
		}
		prev = current
	}
	// Flush the final streak.
	emit(start, prev)
	return ranges
}
|
||||
|
||||
func (g *Cloud) getBackendServiceLink(name string) string {
|
||||
|
@ -21,6 +21,7 @@ package gce
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -1395,3 +1396,71 @@ func TestEnsureInternalLoadBalancerCustomSubnet(t *testing.T) {
|
||||
}
|
||||
assertInternalLbResourcesDeleted(t, gce, svc, vals, true)
|
||||
}
|
||||
|
||||
func TestGetPortRanges(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tc := range []struct {
|
||||
Desc string
|
||||
Input []int
|
||||
Result []string
|
||||
}{
|
||||
{Desc: "All Unique", Input: []int{8, 66, 23, 13, 89}, Result: []string{"8", "13", "23", "66", "89"}},
|
||||
{Desc: "All Unique Sorted", Input: []int{1, 7, 9, 16, 26}, Result: []string{"1", "7", "9", "16", "26"}},
|
||||
{Desc: "Ranges", Input: []int{56, 78, 67, 79, 21, 80, 12}, Result: []string{"12", "21", "56", "67", "78-80"}},
|
||||
{Desc: "Ranges Sorted", Input: []int{5, 7, 90, 1002, 1003, 1004, 1005, 2501}, Result: []string{"5", "7", "90", "1002-1005", "2501"}},
|
||||
{Desc: "Ranges Duplicates", Input: []int{15, 37, 900, 2002, 2003, 2003, 2004, 2004}, Result: []string{"15", "37", "900", "2002-2004"}},
|
||||
{Desc: "Duplicates", Input: []int{10, 10, 10, 10, 10}, Result: []string{"10"}},
|
||||
{Desc: "Only ranges", Input: []int{18, 19, 20, 21, 22, 55, 56, 77, 78, 79, 3504, 3505, 3506}, Result: []string{"18-22", "55-56", "77-79", "3504-3506"}},
|
||||
{Desc: "Single Range", Input: []int{6000, 6001, 6002, 6003, 6004, 6005}, Result: []string{"6000-6005"}},
|
||||
{Desc: "One value", Input: []int{12}, Result: []string{"12"}},
|
||||
{Desc: "Empty", Input: []int{}, Result: nil},
|
||||
} {
|
||||
result := getPortRanges(tc.Input)
|
||||
if !reflect.DeepEqual(result, tc.Result) {
|
||||
t.Errorf("Expected %v, got %v for test case %s", tc.Result, result, tc.Desc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureInternalFirewallPortRanges(t *testing.T) {
|
||||
gce, err := fakeGCECloud(DefaultTestClusterValues())
|
||||
require.NoError(t, err)
|
||||
vals := DefaultTestClusterValues()
|
||||
svc := fakeLoadbalancerService(string(LBTypeInternal))
|
||||
lbName := gce.GetLoadBalancerName(context.TODO(), "", svc)
|
||||
fwName := MakeFirewallName(lbName)
|
||||
tc := struct {
|
||||
Input []int
|
||||
Result []string
|
||||
}{
|
||||
Input: []int{15, 37, 900, 2002, 2003, 2003, 2004, 2004}, Result: []string{"15", "37", "900", "2002-2004"},
|
||||
}
|
||||
c := gce.c.(*cloud.MockGCE)
|
||||
c.MockFirewalls.InsertHook = nil
|
||||
c.MockFirewalls.UpdateHook = nil
|
||||
|
||||
nodes, err := createAndInsertNodes(gce, []string{"test-node-1"}, vals.ZoneName)
|
||||
require.NoError(t, err)
|
||||
sourceRange := []string{"10.0.0.0/20"}
|
||||
// Manually create a firewall rule with the legacy name - lbName
|
||||
gce.ensureInternalFirewall(
|
||||
svc,
|
||||
fwName,
|
||||
"firewall with legacy name",
|
||||
sourceRange,
|
||||
getPortRanges(tc.Input),
|
||||
v1.ProtocolTCP,
|
||||
nodes,
|
||||
"")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error %v when ensuring legacy firewall %s for svc %+v", err, lbName, svc)
|
||||
}
|
||||
existingFirewall, err := gce.GetFirewall(fwName)
|
||||
if err != nil || existingFirewall == nil || len(existingFirewall.Allowed) == 0 {
|
||||
t.Errorf("Unexpected error %v when looking up firewall %s, Got firewall %+v", err, fwName, existingFirewall)
|
||||
}
|
||||
existingPorts := existingFirewall.Allowed[0].Ports
|
||||
if !reflect.DeepEqual(existingPorts, tc.Result) {
|
||||
t.Errorf("Expected firewall rule with ports %v,got %v", tc.Result, existingPorts)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user