diff --git a/pkg/cloudprovider/providers/aws/aws_loadbalancer.go b/pkg/cloudprovider/providers/aws/aws_loadbalancer.go index 0c2e7553067..3242a057bc3 100644 --- a/pkg/cloudprovider/providers/aws/aws_loadbalancer.go +++ b/pkg/cloudprovider/providers/aws/aws_loadbalancer.go @@ -569,12 +569,17 @@ func filterForIPRangeDescription(securityGroups []*ec2.SecurityGroup, lbName str response := []*ec2.SecurityGroup{} clientRule := fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName) healthRule := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName) + alreadyAdded := sets.NewString() for i := range securityGroups { for j := range securityGroups[i].IpPermissions { for k := range securityGroups[i].IpPermissions[j].IpRanges { description := aws.StringValue(securityGroups[i].IpPermissions[j].IpRanges[k].Description) if description == clientRule || description == healthRule { - response = append(response, securityGroups[i]) + sgIDString := aws.StringValue(securityGroups[i].GroupId) + if !alreadyAdded.Has(sgIDString) { + response = append(response, securityGroups[i]) + alreadyAdded.Insert(sgIDString) + } } } } @@ -599,6 +604,7 @@ func (c *Cloud) getVpcCidrBlock() (*string, error) { // if clientTraffic is false, then only update HealthCheck rules func (c *Cloud) updateInstanceSecurityGroupsForNLBTraffic(actualGroups []*ec2.SecurityGroup, desiredSgIds []string, ports []int64, lbName string, clientCidrs []string, clientTraffic bool) error { + klog.V(8).Infof("updateInstanceSecurityGroupsForNLBTraffic: actualGroups=%v, desiredSgIds=%v, ports=%v, clientTraffic=%v", actualGroups, desiredSgIds, ports, clientTraffic) // Map containing the groups we want to make changes on; the ports to make // changes on; and whether to add or remove it. true to add, false to remove portChanges := map[string]map[int64]bool{} @@ -653,16 +659,16 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLBTraffic(actualGroups []*ec2.Se if add { if clientTraffic { klog.V(2).Infof("Adding rule for client MTU discovery from the network load balancer (%s) to instances (%s)", clientCidrs, instanceSecurityGroupID) - klog.V(2).Infof("Adding rule for client traffic from the network load balancer (%s) to instances (%s)", clientCidrs, instanceSecurityGroupID) + klog.V(2).Infof("Adding rule for client traffic from the network load balancer (%s) to instances (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port) } else { - klog.V(2).Infof("Adding rule for health check traffic from the network load balancer (%s) to instances (%s)", clientCidrs, instanceSecurityGroupID) + klog.V(2).Infof("Adding rule for health check traffic from the network load balancer (%s) to instances (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port) } } else { if clientTraffic { klog.V(2).Infof("Removing rule for client MTU discovery from the network load balancer (%s) to instances (%s)", clientCidrs, instanceSecurityGroupID) - klog.V(2).Infof("Removing rule for client traffic from the network load balancer (%s) to instance (%s)", clientCidrs, instanceSecurityGroupID) + klog.V(2).Infof("Removing rule for client traffic from the network load balancer (%s) to instance (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port) } - klog.V(2).Infof("Removing rule for health check traffic from the network load balancer (%s) to instance (%s)", clientCidrs, instanceSecurityGroupID) + klog.V(2).Infof("Removing rule for health check traffic from the network load balancer (%s) to instance (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port) } if clientTraffic { diff --git a/pkg/cloudprovider/providers/aws/aws_loadbalancer_test.go b/pkg/cloudprovider/providers/aws/aws_loadbalancer_test.go index 9f81ea75cfa..cd5d79ba7a9 100644 --- a/pkg/cloudprovider/providers/aws/aws_loadbalancer_test.go +++ b/pkg/cloudprovider/providers/aws/aws_loadbalancer_test.go @@ -17,9 +17,11 @@ limitations under the License. package aws import ( + "fmt" "testing" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" ) func TestElbProtocolsAreEqual(t *testing.T) { @@ -160,3 +162,63 @@ func TestIsNLB(t *testing.T) { } } } + +func TestSecurityGroupFiltering(t *testing.T) { + grid := []struct { + in []*ec2.SecurityGroup + name string + expected int + description string + }{ + { + in: []*ec2.SecurityGroup{ + { + IpPermissions: []*ec2.IpPermission{ + { + IpRanges: []*ec2.IpRange{ + { + Description: aws.String("an unmanaged"), + }, + }, + }, + }, + }, + }, + name: "unmanaged", + expected: 0, + description: "An environment without managed LBs should have %d, but found %d SecurityGroups", + }, + { + in: []*ec2.SecurityGroup{ + { + IpPermissions: []*ec2.IpPermission{ + { + IpRanges: []*ec2.IpRange{ + { + Description: aws.String("an unmanaged"), + }, + { + Description: aws.String(fmt.Sprintf("%s=%s", NLBClientRuleDescription, "managedlb")), + }, + { + Description: aws.String(fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, "managedlb")), + }, + }, + }, + }, + }, + }, + name: "managedlb", + expected: 1, + description: "Found %d, but should have %d Security Groups", + }, + } + + for _, g := range grid { + actual := len(filterForIPRangeDescription(g.in, g.name)) + if actual != g.expected { + t.Errorf(g.description, actual, g.expected) + } + } + +}