diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go index 5d421874f91..89b70c6aa8e 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go @@ -210,6 +210,12 @@ const ServiceAnnotationLoadBalancerHCInterval = "service.beta.kubernetes.io/aws- // static IP addresses for the NLB. Only supported on elbv2 (NLB) const ServiceAnnotationLoadBalancerEIPAllocations = "service.beta.kubernetes.io/aws-load-balancer-eip-allocations" +// ServiceAnnotationLoadBalancerTargetNodeLabels is the annotation used on the service +// to specify a comma-separated list of key-value pairs which will be used to select +// the target nodes for the load balancer +// For example: "Key1=Val1,Key2=Val2,KeyNoVal1=,KeyNoVal2" +const ServiceAnnotationLoadBalancerTargetNodeLabels = "service.beta.kubernetes.io/aws-load-balancer-target-node-labels" + // Event key when a volume is stuck on attaching state when being attached to a volume const volumeAttachmentStuck = "VolumeAttachmentStuck" @@ -3568,7 +3574,7 @@ func (c *Cloud) buildELBSecurityGroupList(serviceName types.NamespacedName, load // Create a security group for the load balancer sgName := "k8s-elb-" + loadBalancerName sgDescription := fmt.Sprintf("Security group for Kubernetes ELB %s (%v)", loadBalancerName, serviceName) - securityGroupID, err = c.ensureSecurityGroup(sgName, sgDescription, getLoadBalancerAdditionalTags(annotations)) + securityGroupID, err = c.ensureSecurityGroup(sgName, sgDescription, getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerAdditionalTags)) if err != nil { klog.Errorf("Error creating load balancer security group: %q", err) return nil, setupSg, err @@ -3686,7 +3692,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS return nil, fmt.Errorf("LoadBalancerIP cannot be specified for AWS ELB") } - instances, err := c.findInstancesForELB(nodes) + instances, err := c.findInstancesForELB(nodes, annotations) if err != nil { return nil, err } @@ -4470,7 +4476,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin // UpdateLoadBalancer implements LoadBalancer.UpdateLoadBalancer func (c *Cloud) UpdateLoadBalancer(ctx context.Context, clusterName string, service *v1.Service, nodes []*v1.Node) error { - instances, err := c.findInstancesForELB(nodes) + instances, err := c.findInstancesForELB(nodes, service.Annotations) if err != nil { return err } diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go index a118400764a..ac039c5e0ac 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go @@ -84,12 +84,11 @@ type nlbPortMapping struct { SSLPolicy string } -// getLoadBalancerAdditionalTags converts the comma separated list of key-value -// pairs in the ServiceAnnotationLoadBalancerAdditionalTags annotation and returns -// it as a map. -func getLoadBalancerAdditionalTags(annotations map[string]string) map[string]string { +// getKeyValuePropertiesFromAnnotation converts the comma separated list of key-value +// pairs from the specified annotation and returns it as a map. +func getKeyValuePropertiesFromAnnotation(annotations map[string]string, annotation string) map[string]string { additionalTags := make(map[string]string) - if additionalTagsList, ok := annotations[ServiceAnnotationLoadBalancerAdditionalTags]; ok { + if additionalTagsList, ok := annotations[annotation]; ok { additionalTagsList = strings.TrimSpace(additionalTagsList) // Break up list of "Key1=Val,Key2=Val2" @@ -123,7 +122,7 @@ func (c *Cloud) ensureLoadBalancerv2(namespacedName types.NamespacedName, loadBa dirty := false // Get additional tags set by the user - tags := getLoadBalancerAdditionalTags(annotations) + tags := getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerAdditionalTags) // Add default tags tags[TagNameKubernetesService] = namespacedName.String() tags = c.tagging.buildTags(ResourceLifecycleOwned, tags) @@ -939,7 +938,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala } // Get additional tags set by the user - tags := getLoadBalancerAdditionalTags(annotations) + tags := getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerAdditionalTags) // Add default tags tags[TagNameKubernetesService] = namespacedName.String() @@ -1128,7 +1127,7 @@ func (c *Cloud) ensureLoadBalancer(namespacedName types.NamespacedName, loadBala { // Add additional tags klog.V(2).Infof("Creating additional load balancer tags for %s", loadBalancerName) - tags := getLoadBalancerAdditionalTags(annotations) + tags := getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerAdditionalTags) if len(tags) > 0 { err := c.addLoadBalancerTags(loadBalancerName, tags) if err != nil { @@ -1521,9 +1520,12 @@ func proxyProtocolEnabled(backend *elb.BackendServerDescription) bool { // findInstancesForELB gets the EC2 instances corresponding to the Nodes, for setting up an ELB // We ignore Nodes (with a log message) where the instanceid cannot be determined from the provider, // and we ignore instances which are not found -func (c *Cloud) findInstancesForELB(nodes []*v1.Node) (map[InstanceID]*ec2.Instance, error) { +func (c *Cloud) findInstancesForELB(nodes []*v1.Node, annotations map[string]string) (map[InstanceID]*ec2.Instance, error) { + + targetNodes := filterTargetNodes(nodes, annotations) + // Map to instance ids ignoring Nodes where we cannot find the id (but logging) - instanceIDs := mapToAWSInstanceIDsTolerant(nodes) + instanceIDs := mapToAWSInstanceIDsTolerant(targetNodes) cacheCriteria := cacheCriteria{ // MaxAge not required, because we only care about security groups, which should not change @@ -1539,3 +1541,35 @@ func (c *Cloud) findInstancesForELB(nodes []*v1.Node) (map[InstanceID]*ec2.Insta return instances, nil } + +// filterTargetNodes uses node labels to filter the nodes that should be targeted by the ELB, +// checking if all the labels provided in an annotation are present in the nodes +func filterTargetNodes(nodes []*v1.Node, annotations map[string]string) []*v1.Node { + + targetNodeLabels := getKeyValuePropertiesFromAnnotation(annotations, ServiceAnnotationLoadBalancerTargetNodeLabels) + + if len(targetNodeLabels) == 0 { + return nodes + } + + targetNodes := make([]*v1.Node, 0, len(nodes)) + + for _, node := range nodes { + if node.Labels != nil && len(node.Labels) > 0 { + allFiltersMatch := true + + for targetLabelKey, targetLabelValue := range targetNodeLabels { + if nodeLabelValue, ok := node.Labels[targetLabelKey]; !ok || (nodeLabelValue != targetLabelValue && targetLabelValue != "") { + allFiltersMatch = false + break + } + } + + if allFiltersMatch { + targetNodes = append(targetNodes, node) + } + } + } + + return targetNodes +} diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer_test.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer_test.go index 16574b058c0..cfee659300a 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer_test.go @@ -25,6 +25,8 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/elb" "github.com/stretchr/testify/assert" + + "k8s.io/api/core/v1" ) func TestElbProtocolsAreEqual(t *testing.T) { @@ -420,3 +422,57 @@ func TestBuildTargetGroupName(t *testing.T) { }) } } + +func TestFilterTargetNodes(t *testing.T) { + tests := []struct { + name string + nodeLabels, annotations map[string]string + nodeTargeted bool + }{ + { + name: "when no filter is provided, node should be targeted", + nodeLabels: map[string]string{"k1": "v1"}, + nodeTargeted: true, + }, + { + name: "when all key-value filters match, node should be targeted", + nodeLabels: map[string]string{"k1": "v1", "k2": "v2"}, + annotations: map[string]string{ServiceAnnotationLoadBalancerTargetNodeLabels: "k1=v1,k2=v2"}, + nodeTargeted: true, + }, + { + name: "when all just-key filter match, node should be targeted", + nodeLabels: map[string]string{"k1": "v1", "k2": "v2"}, + annotations: map[string]string{ServiceAnnotationLoadBalancerTargetNodeLabels: "k1,k2"}, + nodeTargeted: true, + }, + { + name: "when some filters do not match, node should not be targeted", + nodeLabels: map[string]string{"k1": "v1"}, + annotations: map[string]string{ServiceAnnotationLoadBalancerTargetNodeLabels: "k1=v1,k2"}, + nodeTargeted: false, + }, + { + name: "when no filter matches, node should not be targeted", + nodeLabels: map[string]string{"k1": "v1", "k2": "v2"}, + annotations: map[string]string{ServiceAnnotationLoadBalancerTargetNodeLabels: "k3=v3"}, + nodeTargeted: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + node := &v1.Node{} + node.Labels = test.nodeLabels + + nodes := []*v1.Node{node} + targetNodes := filterTargetNodes(nodes, test.annotations) + + if test.nodeTargeted { + assert.Equal(t, nodes, targetNodes) + } else { + assert.Empty(t, targetNodes) + } + }) + } +} diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_test.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_test.go index bd045ecce92..a0dac1d5d6d 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_test.go @@ -1530,7 +1530,7 @@ func TestProxyProtocolEnabled(t *testing.T) { assert.False(t, result, "did not expect to find %s in %s", ProxyProtocolPolicyName, policies) } -func TestGetLoadBalancerAdditionalTags(t *testing.T) { +func TestGetKeyValuePropertiesFromAnnotation(t *testing.T) { tagTests := []struct { Annotations map[string]string Tags map[string]string @@ -1581,7 +1581,7 @@ func TestGetLoadBalancerAdditionalTags(t *testing.T) { } for _, tagTest := range tagTests { - result := getLoadBalancerAdditionalTags(tagTest.Annotations) + result := getKeyValuePropertiesFromAnnotation(tagTest.Annotations, ServiceAnnotationLoadBalancerAdditionalTags) for k, v := range result { if len(result) != len(tagTest.Tags) { t.Errorf("incorrect expected length: %v != %v", result, tagTest.Tags)