From f76c21cce6237b89f81883cef27b182f825dbab5 Mon Sep 17 00:00:00 2001 From: Kishor Joshi Date: Fri, 12 Jun 2020 12:57:37 -0700 Subject: [PATCH] Allow UDP for AWS NLB Co-authored-by: Patrick Ryan Co-authored-by: Owen Ou --- .../k8s.io/legacy-cloud-providers/aws/aws.go | 30 ++++++++---- .../aws/aws_loadbalancer.go | 40 ++++++++++------ .../legacy-cloud-providers/aws/aws_test.go | 47 +++++++++++++++++++ 3 files changed, 95 insertions(+), 22 deletions(-) diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go index 901444df47d..9bcfc71a23f 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go @@ -3660,9 +3660,10 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS sslPorts := getPortSets(annotations[ServiceAnnotationLoadBalancerSSLPorts]) for _, port := range apiService.Spec.Ports { - if port.Protocol != v1.ProtocolTCP { - return nil, fmt.Errorf("Only TCP LoadBalancer is supported for AWS ELB") + if err := checkProtocol(port, annotations); err != nil { + return nil, err } + if port.NodePort == 0 { klog.Errorf("Ignoring port without NodePort defined: %v", port) continue @@ -3682,7 +3683,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS } certificateARN := annotations[ServiceAnnotationLoadBalancerCertificate] - if certificateARN != "" && (sslPorts == nil || sslPorts.numbers.Has(int64(port.Port)) || sslPorts.names.Has(port.Name)) { + if port.Protocol != v1.ProtocolUDP && certificateARN != "" && (sslPorts == nil || sslPorts.numbers.Has(int64(port.Port)) || sslPorts.names.Has(port.Name)) { portMapping.FrontendProtocol = elbv2.ProtocolEnumTls portMapping.SSLCertificateARN = certificateARN portMapping.SSLPolicy = annotations[ServiceAnnotationLoadBalancerSSLNegotiationPolicy] @@ -3693,12 +3694,13 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS } v2Mappings = append(v2Mappings, portMapping) + } else { + listener, err := buildListener(port, annotations, sslPorts) + if err != nil { + return nil, err + } + listeners = append(listeners, listener) } - listener, err := buildListener(port, annotations, sslPorts) - if err != nil { - return nil, err - } - listeners = append(listeners, listener) } if apiService.Spec.LoadBalancerIP != "" { @@ -4739,6 +4741,18 @@ func (c *Cloud) nodeNameToProviderID(nodeName types.NodeName) (InstanceID, error return KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() } +func checkProtocol(port v1.ServicePort, annotations map[string]string) error { + // nlb supports tcp, udp + if isNLB(annotations) && (port.Protocol == v1.ProtocolTCP || port.Protocol == v1.ProtocolUDP) { + return nil + } + // elb only supports tcp + if !isNLB(annotations) && port.Protocol == v1.ProtocolTCP { + return nil + } + return fmt.Errorf("Protocol %s not supported by LoadBalancer", port.Protocol) +} + func setNodeDisk( nodeDiskMap map[types.NodeName]map[KubernetesVolumeID]bool, volumeID KubernetesVolumeID, diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go index ac039c5e0ac..cd3b3c6142d 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go @@ -185,9 +185,12 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa } // actual maps FrontendPort to an elbv2.Listener - actual := map[int64]*elbv2.Listener{} + actual := map[int64]map[string]*elbv2.Listener{} for _, listener := range listenerDescriptions.Listeners { - actual[*listener.Port] = listener + if actual[*listener.Port] == nil { + actual[*listener.Port] = map[string]*elbv2.Listener{} + } + actual[*listener.Port][*listener.Protocol] = listener } actualTargetGroups, err := c.elbv2.DescribeTargetGroups( @@ -207,10 +210,11 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa // Handle additions/modifications for _, mapping := range mappings { frontendPort := mapping.FrontendPort + frontendProtocol := mapping.FrontendProtocol nodePort := mapping.TrafficPort // modifications - if listener, ok := actual[frontendPort]; ok { + if listener, ok := actual[frontendPort][frontendProtocol]; ok { listenerNeedsModification := false if aws.StringValue(listener.Protocol) != mapping.FrontendProtocol { @@ -315,23 +319,27 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa dirty = true } - frontEndPorts := map[int64]bool{} + frontEndPorts := map[int64]map[string]bool{} for i := range mappings { - frontEndPorts[mappings[i].FrontendPort] = true + if frontEndPorts[mappings[i].FrontendPort] == nil { + frontEndPorts[mappings[i].FrontendPort] = map[string]bool{} + } + frontEndPorts[mappings[i].FrontendPort][mappings[i].FrontendProtocol] = true } // handle deletions - for port, listener := range actual { - if _, ok := frontEndPorts[port]; !ok { - err := c.deleteListenerV2(listener) - if err != nil { - return nil, err + for port := range actual { + for protocol := range actual[port] { + if _, ok := frontEndPorts[port][protocol]; !ok { + err := c.deleteListenerV2(actual[port][protocol]) + if err != nil { + return nil, err + } + dirty = true } - dirty = true } } } - if err := c.reconcileLBAttributes(aws.StringValue(loadBalancer.LoadBalancerArn), annotations); err != nil { return nil, err } @@ -765,10 +773,14 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ { clientPorts := sets.Int64{} + clientProtocol := "tcp" healthCheckPorts := sets.Int64{} for _, port := range portMappings { clientPorts.Insert(port.TrafficPort) healthCheckPorts.Insert(port.HealthCheckPort) + if port.TrafficProtocol == string(v1.ProtocolUDP) { + clientProtocol = "udp" + } } clientRuleAnnotation := fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName) healthRuleAnnotation := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName) @@ -782,14 +794,14 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[ if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", healthCheckPorts, vpcCIDRs); err != nil { return err } - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, "tcp", clientPorts, clientCIDRs); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, clientProtocol, clientPorts, clientCIDRs); err != nil { return err } } else { if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", nil, nil); err != nil { return err } - if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, "tcp", nil, nil); err != nil { + if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, clientProtocol, nil, nil); err != nil { return err } } diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_test.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_test.go index 28c6d56217d..7867790f507 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_test.go @@ -1371,6 +1371,53 @@ func TestDescribeLoadBalancerOnEnsure(t *testing.T) { c.EnsureLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}, []*v1.Node{}) } +func TestCheckProtocol(t *testing.T) { + tests := []struct { + name string + annotations map[string]string + port v1.ServicePort + wantErr error + }{ + { + name: "TCP with ELB", + annotations: make(map[string]string), + port: v1.ServicePort{Protocol: v1.ProtocolTCP, Port: int32(8080)}, + wantErr: nil, + }, + { + name: "TCP with NLB", + annotations: map[string]string{ServiceAnnotationLoadBalancerType: "nlb"}, + port: v1.ServicePort{Protocol: v1.ProtocolTCP, Port: int32(8080)}, + wantErr: nil, + }, + { + name: "UDP with ELB", + annotations: make(map[string]string), + port: v1.ServicePort{Protocol: v1.ProtocolUDP, Port: int32(8080)}, + wantErr: fmt.Errorf("Protocol UDP not supported by load balancer"), + }, + { + name: "UDP with NLB", + annotations: map[string]string{ServiceAnnotationLoadBalancerType: "nlb"}, + port: v1.ServicePort{Protocol: v1.ProtocolUDP, Port: int32(8080)}, + wantErr: nil, + }, + } + for _, test := range tests { + tt := test + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := checkProtocol(tt.port, tt.annotations) + if tt.wantErr != nil && err == nil { + t.Errorf("Expected error: want=%s got =%s", tt.wantErr, err) + } + if tt.wantErr == nil && err != nil { + t.Errorf("Unexpected error: want=%s got =%s", tt.wantErr, err) + } + }) + } +} + func TestBuildListener(t *testing.T) { tests := []struct { name string