diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go index 07f0bc064cb..d74f89c7e15 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go @@ -3579,7 +3579,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS sourceRangeCidrs = append(sourceRangeCidrs, "0.0.0.0/0") } - err = c.updateInstanceSecurityGroupsForNLB(v2Mappings, instances, loadBalancerName, sourceRangeCidrs) + err = c.updateInstanceSecurityGroupsForNLB(loadBalancerName, instances, sourceRangeCidrs, v2Mappings) if err != nil { klog.Warningf("Error opening ingress rules for the load balancer to the instances: %q", err) return nil, err @@ -4158,99 +4158,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin } } - { - var matchingGroups []*ec2.SecurityGroup - { - // Server side filter - describeRequest := &ec2.DescribeSecurityGroupsInput{} - describeRequest.Filters = []*ec2.Filter{ - newEc2Filter("ip-permission.protocol", "tcp"), - } - response, err := c.ec2.DescribeSecurityGroups(describeRequest) - if err != nil { - return fmt.Errorf("Error querying security groups for NLB: %q", err) - } - for _, sg := range response { - if !c.tagging.hasClusterTag(sg.Tags) { - continue - } - matchingGroups = append(matchingGroups, sg) - } - - // client-side filter out groups that don't have IP Rules we've - // annotated for this service - matchingGroups = filterForIPRangeDescription(matchingGroups, loadBalancerName) - } - - { - clientRule := fmt.Sprintf("%s=%s", NLBClientRuleDescription, loadBalancerName) - mtuRule := fmt.Sprintf("%s=%s", NLBMtuDiscoveryRuleDescription, loadBalancerName) - healthRule := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, loadBalancerName) - - for i := range matchingGroups { - removes := []*ec2.IpPermission{} - for j := range matchingGroups[i].IpPermissions { - - v4rangesToRemove := []*ec2.IpRange{} - v6rangesToRemove := []*ec2.Ipv6Range{} - - // Find IpPermission that contains k8s description - // If we removed the whole IpPermission, it could contain other non-k8s specified ranges - for k := range matchingGroups[i].IpPermissions[j].IpRanges { - description := aws.StringValue(matchingGroups[i].IpPermissions[j].IpRanges[k].Description) - if description == clientRule || description == mtuRule || description == healthRule { - v4rangesToRemove = append(v4rangesToRemove, matchingGroups[i].IpPermissions[j].IpRanges[k]) - } - } - - // Find IpPermission that contains k8s description - // If we removed the whole IpPermission, it could contain other non-k8s specified rangesk - for k := range matchingGroups[i].IpPermissions[j].Ipv6Ranges { - description := aws.StringValue(matchingGroups[i].IpPermissions[j].Ipv6Ranges[k].Description) - if description == clientRule || description == mtuRule || description == healthRule { - v6rangesToRemove = append(v6rangesToRemove, matchingGroups[i].IpPermissions[j].Ipv6Ranges[k]) - } - } - - // ipv4 and ipv6 removals cannot be included in the same permission - if len(v4rangesToRemove) > 0 { - // create a new *IpPermission to not accidentally remove UserIdGroupPairs - removedPermission := &ec2.IpPermission{ - FromPort: matchingGroups[i].IpPermissions[j].FromPort, - IpProtocol: matchingGroups[i].IpPermissions[j].IpProtocol, - IpRanges: v4rangesToRemove, - ToPort: matchingGroups[i].IpPermissions[j].ToPort, - } - removes = append(removes, removedPermission) - } - if len(v6rangesToRemove) > 0 { - // create a new *IpPermission to not accidentally remove UserIdGroupPairs - removedPermission := &ec2.IpPermission{ - FromPort: matchingGroups[i].IpPermissions[j].FromPort, - IpProtocol: matchingGroups[i].IpPermissions[j].IpProtocol, - Ipv6Ranges: v6rangesToRemove, - ToPort: matchingGroups[i].IpPermissions[j].ToPort, - } - removes = append(removes, removedPermission) - } - - } - if len(removes) > 0 { - changed, err := c.removeSecurityGroupIngress(aws.StringValue(matchingGroups[i].GroupId), removes) - if err != nil { - return err - } - if !changed { - klog.Warning("Revoking ingress was not needed; concurrent change? groupId=", *matchingGroups[i].GroupId) - } - } - - } - - } - - } - return nil + return c.updateInstanceSecurityGroupsForNLB(loadBalancerName, nil, nil, nil) } lb, err := c.describeLoadBalancer(loadBalancerName) diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go index 3d6b50840d7..ca72cad4b4b 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer.go @@ -68,7 +68,6 @@ type nlbPortMapping struct { TrafficPort int64 TrafficProtocol string - ClientCIDR string HealthCheckPort int64 HealthCheckPath string @@ -648,50 +647,6 @@ func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName ty return targetGroup, nil } -func portsForNLB(lbName string, sg *ec2.SecurityGroup, clientTraffic bool) sets.Int64 { - response := sets.NewInt64() - var annotation string - if clientTraffic { - annotation = fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName) - } else { - annotation = fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName) - } - - for i := range sg.IpPermissions { - for j := range sg.IpPermissions[i].IpRanges { - description := aws.StringValue(sg.IpPermissions[i].IpRanges[j].Description) - if description == annotation { - // TODO should probably check FromPort == ToPort - response.Insert(aws.Int64Value(sg.IpPermissions[i].FromPort)) - } - } - } - return response -} - -// filterForIPRangeDescription filters in security groups that have IpRange Descriptions that match a loadBalancerName -func filterForIPRangeDescription(securityGroups []*ec2.SecurityGroup, lbName string) []*ec2.SecurityGroup { - response := []*ec2.SecurityGroup{} - clientRule := fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName) - healthRule := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName) - alreadyAdded := sets.NewString() - for i := range securityGroups { - for j := range securityGroups[i].IpPermissions { - for k := range securityGroups[i].IpPermissions[j].IpRanges { - description := aws.StringValue(securityGroups[i].IpPermissions[j].IpRanges[k].Description) - if description == clientRule || description == healthRule { - sgIDString := aws.StringValue(securityGroups[i].GroupId) - if !alreadyAdded.Has(sgIDString) { - response = append(response, securityGroups[i]) - alreadyAdded.Insert(sgIDString) - } - } - } - } - } - return response -} - func (c *Cloud) getVpcCidrBlocks() ([]string, error) { vpcs, err := c.ec2.DescribeVpcs(&ec2.DescribeVpcsInput{ VpcIds: []*string{aws.String(c.vpcID)}, @@ -710,203 +665,76 @@ func (c *Cloud) getVpcCidrBlocks() ([]string, error) { return cidrBlocks, nil } -// abstraction for updating SG rules -// if clientTraffic is false, then only update HealthCheck rules -func (c *Cloud) updateInstanceSecurityGroupsForNLBTraffic(actualGroups []*ec2.SecurityGroup, desiredSgIds []string, ports []int64, lbName string, clientCidrs []string, clientTraffic bool) error { - - klog.V(8).Infof("updateInstanceSecurityGroupsForNLBTraffic: actualGroups=%v, desiredSgIds=%v, ports=%v, clientTraffic=%v", actualGroups, desiredSgIds, ports, clientTraffic) - // Map containing the groups we want to make changes on; the ports to make - // changes on; and whether to add or remove it. true to add, false to remove - portChanges := map[string]map[int64]bool{} - - for _, id := range desiredSgIds { - // consider everything an addition for now - if _, ok := portChanges[id]; !ok { - portChanges[id] = make(map[int64]bool) - } - for _, port := range ports { - portChanges[id][port] = true - } +// updateInstanceSecurityGroupsForNLB will adjust securityGroup's settings to allow inbound traffic into instances from clientCIDRs and portMappings. +// TIP: if either instances or clientCIDRs or portMappings are nil, then the securityGroup rules for lbName are cleared. +func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[InstanceID]*ec2.Instance, clientCIDRs []string, portMappings []nlbPortMapping) error { + if c.cfg.Global.DisableSecurityGroupIngress { + return nil } - // Compare to actual groups - for _, actualGroup := range actualGroups { - actualGroupID := aws.StringValue(actualGroup.GroupId) - if actualGroupID == "" { - klog.Warning("Ignoring group without ID: ", actualGroup) + clusterSGs, err := c.getTaggedSecurityGroups() + if err != nil { + return fmt.Errorf("error querying for tagged security groups: %q", err) + } + // scan instances for groups we want to open + desiredSGIDs := sets.String{} + for _, instance := range instances { + sg, err := findSecurityGroupForInstance(instance, clusterSGs) + if err != nil { + return err + } + if sg == nil { + klog.Warningf("Ignoring instance without security group: %s", aws.StringValue(instance.InstanceId)) continue } + desiredSGIDs.Insert(aws.StringValue(sg.GroupId)) + } - addingMap, ok := portChanges[actualGroupID] - if ok { - desiredSet := sets.NewInt64() - for port := range addingMap { - desiredSet.Insert(port) - } - existingSet := portsForNLB(lbName, actualGroup, clientTraffic) - - // remove from portChanges ports that are already allowed - if intersection := desiredSet.Intersection(existingSet); intersection.Len() > 0 { - for p := range intersection { - delete(portChanges[actualGroupID], p) - } - } - - // allowed ports that need to be removed - if difference := existingSet.Difference(desiredSet); difference.Len() > 0 { - for p := range difference { - portChanges[actualGroupID][p] = false - } + // TODO(@M00nF1sh): do we really needs to support SG without cluster tag at current version? + // findSecurityGroupForInstance might return SG that are not tagged. + { + for sgID := range desiredSGIDs.Difference(sets.StringKeySet(clusterSGs)) { + sg, err := c.findSecurityGroup(sgID) + if err != nil { + return fmt.Errorf("error finding instance group: %q", err) } + clusterSGs[sgID] = sg } } - // Make changes we've planned on - for instanceSecurityGroupID, portMap := range portChanges { - adds := []*ec2.IpPermission{} - removes := []*ec2.IpPermission{} - for port, add := range portMap { - if add { - if clientTraffic { - klog.V(2).Infof("Adding rule for client MTU discovery from the network load balancer (%s) to instances (%s)", clientCidrs, instanceSecurityGroupID) - klog.V(2).Infof("Adding rule for client traffic from the network load balancer (%s) to instances (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port) - } else { - klog.V(2).Infof("Adding rule for health check traffic from the network load balancer (%s) to instances (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port) - } - } else { - if clientTraffic { - klog.V(2).Infof("Removing rule for client MTU discovery from the network load balancer (%s) to instances (%s)", clientCidrs, instanceSecurityGroupID) - klog.V(2).Infof("Removing rule for client traffic from the network load balancer (%s) to instance (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port) - } - klog.V(2).Infof("Removing rule for health check traffic from the network load balancer (%s) to instance (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port) - } - - if clientTraffic { - clientRuleAnnotation := fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName) - // Client Traffic - permission := &ec2.IpPermission{ - FromPort: aws.Int64(port), - ToPort: aws.Int64(port), - IpProtocol: aws.String("tcp"), - } - ranges := []*ec2.IpRange{} - for _, cidr := range clientCidrs { - ranges = append(ranges, &ec2.IpRange{ - CidrIp: aws.String(cidr), - Description: aws.String(clientRuleAnnotation), - }) - } - permission.IpRanges = ranges - if add { - adds = append(adds, permission) - } else { - removes = append(removes, permission) - } - } else { - healthRuleAnnotation := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName) - - // NLB HealthCheck - permission := &ec2.IpPermission{ - FromPort: aws.Int64(port), - ToPort: aws.Int64(port), - IpProtocol: aws.String("tcp"), - } - ranges := []*ec2.IpRange{} - for _, cidr := range clientCidrs { - ranges = append(ranges, &ec2.IpRange{ - CidrIp: aws.String(cidr), - Description: aws.String(healthRuleAnnotation), - }) - } - permission.IpRanges = ranges - if add { - adds = append(adds, permission) - } else { - removes = append(removes, permission) - } - } + { + clientPorts := sets.Int64{} + healthCheckPorts := sets.Int64{} + for _, port := range portMappings { + clientPorts.Insert(port.TrafficPort) + healthCheckPorts.Insert(port.HealthCheckPort) } - - if len(adds) > 0 { - changed, err := c.addSecurityGroupIngress(instanceSecurityGroupID, adds) - if err != nil { - return err - } - if !changed { - klog.Warning("Allowing ingress was not needed; concurrent change? groupId=", instanceSecurityGroupID) - } + clientRuleAnnotation := fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName) + healthRuleAnnotation := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName) + vpcCIDRs, err := c.getVpcCidrBlocks() + if err != nil { + return err } - - if len(removes) > 0 { - changed, err := c.removeSecurityGroupIngress(instanceSecurityGroupID, removes) - if err != nil { - return err - } - if !changed { - klog.Warning("Revoking ingress was not needed; concurrent change? groupId=", instanceSecurityGroupID) - } - } - - if clientTraffic { - // MTU discovery - mtuRuleAnnotation := fmt.Sprintf("%s=%s", NLBMtuDiscoveryRuleDescription, lbName) - mtuPermission := &ec2.IpPermission{ - IpProtocol: aws.String("icmp"), - FromPort: aws.Int64(3), - ToPort: aws.Int64(4), - } - ranges := []*ec2.IpRange{} - for _, cidr := range clientCidrs { - ranges = append(ranges, &ec2.IpRange{ - CidrIp: aws.String(cidr), - Description: aws.String(mtuRuleAnnotation), - }) - } - mtuPermission.IpRanges = ranges - - group, err := c.findSecurityGroup(instanceSecurityGroupID) - if err != nil { - klog.Warningf("Error retrieving security group: %q", err) - return err - } - - if group == nil { - klog.Warning("Security group not found: ", instanceSecurityGroupID) - return nil - } - - icmpExists := false - permCount := 0 - for _, perm := range group.IpPermissions { - if *perm.IpProtocol == "icmp" { - icmpExists = true - continue - } - - if perm.FromPort != nil { - permCount++ - } - } - - if !icmpExists && permCount > 0 { - // the icmp permission is missing - changed, err := c.addSecurityGroupIngress(instanceSecurityGroupID, []*ec2.IpPermission{mtuPermission}) - if err != nil { - klog.Warningf("Error adding MTU permission to security group: %q", err) + for sgID, sg := range clusterSGs { + sgPerms := NewIPPermissionSet(sg.IpPermissions...).Ungroup() + if desiredSGIDs.Has(sgID) { + if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", healthCheckPorts, vpcCIDRs); err != nil { return err } - if !changed { - klog.Warning("Allowing ingress was not needed; concurrent change? groupId=", instanceSecurityGroupID) - } - } else if icmpExists && permCount == 0 { - // there is no additional permissions, remove icmp - changed, err := c.removeSecurityGroupIngress(instanceSecurityGroupID, []*ec2.IpPermission{mtuPermission}) - if err != nil { - klog.Warningf("Error removing MTU permission to security group: %q", err) + if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, "tcp", clientPorts, clientCIDRs); err != nil { return err } - if !changed { - klog.Warning("Revoking ingress was not needed; concurrent change? groupId=", instanceSecurityGroupID) + } else { + if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", nil, nil); err != nil { + return err + } + if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, "tcp", nil, nil); err != nil { + return err + } + } + if !sgPerms.Equal(NewIPPermissionSet(sg.IpPermissions...).Ungroup()) { + if err := c.updateInstanceSecurityGroupForNLBMTU(sgID, sgPerms); err != nil { + return err } } } @@ -914,102 +742,105 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLBTraffic(actualGroups []*ec2.Se return nil } -// Add SG rules for a given NLB -func (c *Cloud) updateInstanceSecurityGroupsForNLB(mappings []nlbPortMapping, instances map[InstanceID]*ec2.Instance, lbName string, clientCidrs []string) error { - if c.cfg.Global.DisableSecurityGroupIngress { - return nil - } - - vpcCidrBlocks, err := c.getVpcCidrBlocks() - if err != nil { - return err - } - - // Unlike the classic ELB, NLB does not have a security group that we can - // filter against all existing groups to see if they allow access. Instead - // we use the IpRange.Description field to annotate NLB health check and - // client traffic rules - - // Get the actual list of groups that allow ingress for the load-balancer - var actualGroups []*ec2.SecurityGroup - { - // Server side filter - describeRequest := &ec2.DescribeSecurityGroupsInput{} - describeRequest.Filters = []*ec2.Filter{ - newEc2Filter("ip-permission.protocol", "tcp"), - newEc2Filter("vpc-id", c.vpcID), +// updateInstanceSecurityGroupForNLBTraffic will manage permissions set(identified by ruleDesc) on securityGroup to match desired set(allow protocol traffic from ports/cidr). +// Note: sgPerms will be updated to reflect the current permission set on SG after update. +func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(sgID string, sgPerms IPPermissionSet, ruleDesc string, protocol string, ports sets.Int64, cidrs []string) error { + desiredPerms := NewIPPermissionSet() + for port := range ports { + for _, cidr := range cidrs { + desiredPerms.Insert(&ec2.IpPermission{ + IpProtocol: aws.String(protocol), + FromPort: aws.Int64(port), + ToPort: aws.Int64(port), + IpRanges: []*ec2.IpRange{ + { + CidrIp: aws.String(cidr), + Description: aws.String(ruleDesc), + }, + }, + }) } - response, err := c.ec2.DescribeSecurityGroups(describeRequest) - if err != nil { - return fmt.Errorf("Error querying security groups for NLB: %q", err) - } - for _, sg := range response { - if !c.tagging.hasClusterTag(sg.Tags) { - continue - } - actualGroups = append(actualGroups, sg) - } - - // client-side filter - // Filter out groups that don't have IP Rules we've annotated for this service - actualGroups = filterForIPRangeDescription(actualGroups, lbName) - } - - taggedSecurityGroups, err := c.getTaggedSecurityGroups() - if err != nil { - return fmt.Errorf("Error querying for tagged security groups: %q", err) - } - - externalTrafficPolicyIsLocal := false - trafficPorts := []int64{} - for i := range mappings { - trafficPorts = append(trafficPorts, mappings[i].TrafficPort) - if mappings[i].TrafficPort != mappings[i].HealthCheckPort { - externalTrafficPolicyIsLocal = true - } - } - - healthCheckPorts := trafficPorts - // if externalTrafficPolicy is Local, all listeners use the same health - // check port - if externalTrafficPolicyIsLocal && len(mappings) > 0 { - healthCheckPorts = []int64{mappings[0].HealthCheckPort} - } - - desiredGroupIds := []string{} - // Scan instances for groups we want open - for _, instance := range instances { - securityGroup, err := findSecurityGroupForInstance(instance, taggedSecurityGroups) + } + + permsToGrant := desiredPerms.Difference(sgPerms) + permsToRevoke := sgPerms.Difference(desiredPerms) + permsToRevoke.DeleteIf(IPPermissionNotMatch{IPPermissionMatchDesc{ruleDesc}}) + if len(permsToRevoke) > 0 { + permsToRevokeList := permsToRevoke.List() + changed, err := c.removeSecurityGroupIngress(sgID, permsToRevokeList) if err != nil { + klog.Warningf("Error remove traffic permission from security group: %q", err) return err } + if !changed { + klog.Warning("Revoking ingress was not needed; concurrent change? groupId=", sgID) + } + sgPerms.Delete(permsToRevokeList...) + } + if len(permsToGrant) > 0 { + permsToGrantList := permsToGrant.List() + changed, err := c.addSecurityGroupIngress(sgID, permsToGrantList) + if err != nil { + klog.Warningf("Error add traffic permission to security group: %q", err) + return err + } + if !changed { + klog.Warning("Allowing ingress was not needed; concurrent change? groupId=", sgID) + } + sgPerms.Insert(permsToGrantList...) + } + return nil +} - if securityGroup == nil { - klog.Warningf("Ignoring instance without security group: %s", aws.StringValue(instance.InstanceId)) - continue +// Note: sgPerms will be updated to reflect the current permission set on SG after update. +func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(sgID string, sgPerms IPPermissionSet) error { + desiredPerms := NewIPPermissionSet() + for _, perm := range sgPerms { + for _, ipRange := range perm.IpRanges { + if strings.Contains(aws.StringValue(ipRange.Description), NLBClientRuleDescription) { + desiredPerms.Insert(&ec2.IpPermission{ + IpProtocol: aws.String("icmp"), + FromPort: aws.Int64(3), + ToPort: aws.Int64(4), + IpRanges: []*ec2.IpRange{ + { + CidrIp: ipRange.CidrIp, + Description: aws.String(NLBMtuDiscoveryRuleDescription), + }, + }, + }) + } + } + } + + permsToGrant := desiredPerms.Difference(sgPerms) + permsToRevoke := sgPerms.Difference(desiredPerms) + permsToRevoke.DeleteIf(IPPermissionNotMatch{IPPermissionMatchDesc{NLBMtuDiscoveryRuleDescription}}) + if len(permsToRevoke) > 0 { + permsToRevokeList := permsToRevoke.List() + changed, err := c.removeSecurityGroupIngress(sgID, permsToRevokeList) + if err != nil { + klog.Warningf("Error remove MTU permission from security group: %q", err) + return err + } + if !changed { + klog.Warning("Revoking ingress was not needed; concurrent change? groupId=", sgID) } - id := aws.StringValue(securityGroup.GroupId) - if id == "" { - klog.Warningf("found security group without id: %v", securityGroup) - continue + sgPerms.Delete(permsToRevokeList...) + } + if len(permsToGrant) > 0 { + permsToGrantList := permsToGrant.List() + changed, err := c.addSecurityGroupIngress(sgID, permsToGrantList) + if err != nil { + klog.Warningf("Error add MTU permission to security group: %q", err) + return err } - - desiredGroupIds = append(desiredGroupIds, id) + if !changed { + klog.Warning("Allowing ingress was not needed; concurrent change? groupId=", sgID) + } + sgPerms.Insert(permsToGrantList...) } - - // Run once for Client traffic - err = c.updateInstanceSecurityGroupsForNLBTraffic(actualGroups, desiredGroupIds, trafficPorts, lbName, clientCidrs, true) - if err != nil { - return err - } - - // Run once for health check traffic - err = c.updateInstanceSecurityGroupsForNLBTraffic(actualGroups, desiredGroupIds, healthCheckPorts, lbName, vpcCidrBlocks, false) - if err != nil { - return err - } - return nil } diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer_test.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer_test.go index 0dc09bb7c97..0c4214af5d3 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_loadbalancer_test.go @@ -17,11 +17,9 @@ limitations under the License. package aws import ( - "fmt" "testing" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/elb" "github.com/stretchr/testify/assert" ) @@ -165,66 +163,6 @@ func TestIsNLB(t *testing.T) { } } -func TestSecurityGroupFiltering(t *testing.T) { - grid := []struct { - in []*ec2.SecurityGroup - name string - expected int - description string - }{ - { - in: []*ec2.SecurityGroup{ - { - IpPermissions: []*ec2.IpPermission{ - { - IpRanges: []*ec2.IpRange{ - { - Description: aws.String("an unmanaged"), - }, - }, - }, - }, - }, - }, - name: "unmanaged", - expected: 0, - description: "An environment without managed LBs should have %d, but found %d SecurityGroups", - }, - { - in: []*ec2.SecurityGroup{ - { - IpPermissions: []*ec2.IpPermission{ - { - IpRanges: []*ec2.IpRange{ - { - Description: aws.String("an unmanaged"), - }, - { - Description: aws.String(fmt.Sprintf("%s=%s", NLBClientRuleDescription, "managedlb")), - }, - { - Description: aws.String(fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, "managedlb")), - }, - }, - }, - }, - }, - }, - name: "managedlb", - expected: 1, - description: "Found %d, but should have %d Security Groups", - }, - } - - for _, g := range grid { - actual := len(filterForIPRangeDescription(g.in, g.name)) - if actual != g.expected { - t.Errorf(g.description, actual, g.expected) - } - } - -} - func TestSyncElbListeners(t *testing.T) { tests := []struct { name string diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/sets_ippermissions.go b/staging/src/k8s.io/legacy-cloud-providers/aws/sets_ippermissions.go index d71948f8d22..201b8ac9cf0 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/sets_ippermissions.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/sets_ippermissions.go @@ -20,12 +20,19 @@ import ( "encoding/json" "fmt" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" ) // IPPermissionSet maps IP strings of strings to EC2 IpPermissions type IPPermissionSet map[string]*ec2.IpPermission +// IPPermissionPredicate is an predicate to test whether IPPermission matches some condition. +type IPPermissionPredicate interface { + // Test checks whether specified IPPermission matches condition. + Test(perm *ec2.IpPermission) bool +} + // NewIPPermissionSet creates a new IPPermissionSet func NewIPPermissionSet(items ...*ec2.IpPermission) IPPermissionSet { s := make(IPPermissionSet) @@ -90,6 +97,23 @@ func (s IPPermissionSet) Insert(items ...*ec2.IpPermission) { } } +// Delete delete permission from the set. +func (s IPPermissionSet) Delete(items ...*ec2.IpPermission) { + for _, p := range items { + k := keyForIPPermission(p) + delete(s, k) + } +} + +// DeleteIf delete permission from the set if permission matches predicate. +func (s IPPermissionSet) DeleteIf(predicate IPPermissionPredicate) { + for k, p := range s { + if predicate.Test(p) { + delete(s, k) + } + } +} + // List returns the contents as a slice. Order is not defined. func (s IPPermissionSet) List() []*ec2.IpPermission { res := make([]*ec2.IpPermission, 0, len(s)) @@ -146,3 +170,47 @@ func keyForIPPermission(p *ec2.IpPermission) string { } return string(v) } + +var _ IPPermissionPredicate = IPPermissionMatchDesc{} + +// IPPermissionMatchDesc checks whether specific IPPermission contains description. +type IPPermissionMatchDesc struct { + Description string +} + +// Test whether specific IPPermission contains description. +func (p IPPermissionMatchDesc) Test(perm *ec2.IpPermission) bool { + for _, v4Range := range perm.IpRanges { + if aws.StringValue(v4Range.Description) == p.Description { + return true + } + } + for _, v6Range := range perm.Ipv6Ranges { + if aws.StringValue(v6Range.Description) == p.Description { + return true + } + } + for _, prefixListID := range perm.PrefixListIds { + if aws.StringValue(prefixListID.Description) == p.Description { + return true + } + } + for _, group := range perm.UserIdGroupPairs { + if aws.StringValue(group.Description) == p.Description { + return true + } + } + return false +} + +var _ IPPermissionPredicate = IPPermissionNotMatch{} + +// IPPermissionNotMatch is the *not* operator for Predicate +type IPPermissionNotMatch struct { + Predicate IPPermissionPredicate +} + +// Test whether specific IPPermission not match the embed predicate. +func (p IPPermissionNotMatch) Test(perm *ec2.IpPermission) bool { + return !p.Predicate.Test(perm) +}