Merge pull request #74692 from M00nF1sh/fix_sg

refactor NLB securityGroup handling
This commit is contained in:
Kubernetes Prow Robot 2019-05-29 22:54:18 -07:00 committed by GitHub
commit 009f7a07ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 215 additions and 470 deletions

View File

@ -3579,7 +3579,7 @@ func (c *Cloud) EnsureLoadBalancer(ctx context.Context, clusterName string, apiS
sourceRangeCidrs = append(sourceRangeCidrs, "0.0.0.0/0")
}
err = c.updateInstanceSecurityGroupsForNLB(v2Mappings, instances, loadBalancerName, sourceRangeCidrs)
err = c.updateInstanceSecurityGroupsForNLB(loadBalancerName, instances, sourceRangeCidrs, v2Mappings)
if err != nil {
klog.Warningf("Error opening ingress rules for the load balancer to the instances: %q", err)
return nil, err
@ -4158,99 +4158,7 @@ func (c *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName strin
}
}
{
var matchingGroups []*ec2.SecurityGroup
{
// Server side filter
describeRequest := &ec2.DescribeSecurityGroupsInput{}
describeRequest.Filters = []*ec2.Filter{
newEc2Filter("ip-permission.protocol", "tcp"),
}
response, err := c.ec2.DescribeSecurityGroups(describeRequest)
if err != nil {
return fmt.Errorf("Error querying security groups for NLB: %q", err)
}
for _, sg := range response {
if !c.tagging.hasClusterTag(sg.Tags) {
continue
}
matchingGroups = append(matchingGroups, sg)
}
// client-side filter out groups that don't have IP Rules we've
// annotated for this service
matchingGroups = filterForIPRangeDescription(matchingGroups, loadBalancerName)
}
{
clientRule := fmt.Sprintf("%s=%s", NLBClientRuleDescription, loadBalancerName)
mtuRule := fmt.Sprintf("%s=%s", NLBMtuDiscoveryRuleDescription, loadBalancerName)
healthRule := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, loadBalancerName)
for i := range matchingGroups {
removes := []*ec2.IpPermission{}
for j := range matchingGroups[i].IpPermissions {
v4rangesToRemove := []*ec2.IpRange{}
v6rangesToRemove := []*ec2.Ipv6Range{}
// Find IpPermission that contains k8s description
// If we removed the whole IpPermission, it could contain other non-k8s specified ranges
for k := range matchingGroups[i].IpPermissions[j].IpRanges {
description := aws.StringValue(matchingGroups[i].IpPermissions[j].IpRanges[k].Description)
if description == clientRule || description == mtuRule || description == healthRule {
v4rangesToRemove = append(v4rangesToRemove, matchingGroups[i].IpPermissions[j].IpRanges[k])
}
}
// Find IpPermission that contains k8s description
// If we removed the whole IpPermission, it could contain other non-k8s specified rangesk
for k := range matchingGroups[i].IpPermissions[j].Ipv6Ranges {
description := aws.StringValue(matchingGroups[i].IpPermissions[j].Ipv6Ranges[k].Description)
if description == clientRule || description == mtuRule || description == healthRule {
v6rangesToRemove = append(v6rangesToRemove, matchingGroups[i].IpPermissions[j].Ipv6Ranges[k])
}
}
// ipv4 and ipv6 removals cannot be included in the same permission
if len(v4rangesToRemove) > 0 {
// create a new *IpPermission to not accidentally remove UserIdGroupPairs
removedPermission := &ec2.IpPermission{
FromPort: matchingGroups[i].IpPermissions[j].FromPort,
IpProtocol: matchingGroups[i].IpPermissions[j].IpProtocol,
IpRanges: v4rangesToRemove,
ToPort: matchingGroups[i].IpPermissions[j].ToPort,
}
removes = append(removes, removedPermission)
}
if len(v6rangesToRemove) > 0 {
// create a new *IpPermission to not accidentally remove UserIdGroupPairs
removedPermission := &ec2.IpPermission{
FromPort: matchingGroups[i].IpPermissions[j].FromPort,
IpProtocol: matchingGroups[i].IpPermissions[j].IpProtocol,
Ipv6Ranges: v6rangesToRemove,
ToPort: matchingGroups[i].IpPermissions[j].ToPort,
}
removes = append(removes, removedPermission)
}
}
if len(removes) > 0 {
changed, err := c.removeSecurityGroupIngress(aws.StringValue(matchingGroups[i].GroupId), removes)
if err != nil {
return err
}
if !changed {
klog.Warning("Revoking ingress was not needed; concurrent change? groupId=", *matchingGroups[i].GroupId)
}
}
}
}
}
return nil
return c.updateInstanceSecurityGroupsForNLB(loadBalancerName, nil, nil, nil)
}
lb, err := c.describeLoadBalancer(loadBalancerName)

View File

@ -68,7 +68,6 @@ type nlbPortMapping struct {
TrafficPort int64
TrafficProtocol string
ClientCIDR string
HealthCheckPort int64
HealthCheckPath string
@ -648,50 +647,6 @@ func (c *Cloud) ensureTargetGroup(targetGroup *elbv2.TargetGroup, serviceName ty
return targetGroup, nil
}
func portsForNLB(lbName string, sg *ec2.SecurityGroup, clientTraffic bool) sets.Int64 {
response := sets.NewInt64()
var annotation string
if clientTraffic {
annotation = fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName)
} else {
annotation = fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName)
}
for i := range sg.IpPermissions {
for j := range sg.IpPermissions[i].IpRanges {
description := aws.StringValue(sg.IpPermissions[i].IpRanges[j].Description)
if description == annotation {
// TODO should probably check FromPort == ToPort
response.Insert(aws.Int64Value(sg.IpPermissions[i].FromPort))
}
}
}
return response
}
// filterForIPRangeDescription filters in security groups that have IpRange Descriptions that match a loadBalancerName
func filterForIPRangeDescription(securityGroups []*ec2.SecurityGroup, lbName string) []*ec2.SecurityGroup {
response := []*ec2.SecurityGroup{}
clientRule := fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName)
healthRule := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName)
alreadyAdded := sets.NewString()
for i := range securityGroups {
for j := range securityGroups[i].IpPermissions {
for k := range securityGroups[i].IpPermissions[j].IpRanges {
description := aws.StringValue(securityGroups[i].IpPermissions[j].IpRanges[k].Description)
if description == clientRule || description == healthRule {
sgIDString := aws.StringValue(securityGroups[i].GroupId)
if !alreadyAdded.Has(sgIDString) {
response = append(response, securityGroups[i])
alreadyAdded.Insert(sgIDString)
}
}
}
}
}
return response
}
func (c *Cloud) getVpcCidrBlocks() ([]string, error) {
vpcs, err := c.ec2.DescribeVpcs(&ec2.DescribeVpcsInput{
VpcIds: []*string{aws.String(c.vpcID)},
@ -710,203 +665,76 @@ func (c *Cloud) getVpcCidrBlocks() ([]string, error) {
return cidrBlocks, nil
}
// abstraction for updating SG rules
// if clientTraffic is false, then only update HealthCheck rules
func (c *Cloud) updateInstanceSecurityGroupsForNLBTraffic(actualGroups []*ec2.SecurityGroup, desiredSgIds []string, ports []int64, lbName string, clientCidrs []string, clientTraffic bool) error {
klog.V(8).Infof("updateInstanceSecurityGroupsForNLBTraffic: actualGroups=%v, desiredSgIds=%v, ports=%v, clientTraffic=%v", actualGroups, desiredSgIds, ports, clientTraffic)
// Map containing the groups we want to make changes on; the ports to make
// changes on; and whether to add or remove it. true to add, false to remove
portChanges := map[string]map[int64]bool{}
for _, id := range desiredSgIds {
// consider everything an addition for now
if _, ok := portChanges[id]; !ok {
portChanges[id] = make(map[int64]bool)
}
for _, port := range ports {
portChanges[id][port] = true
}
// updateInstanceSecurityGroupsForNLB will adjust securityGroup's settings to allow inbound traffic into instances from clientCIDRs and portMappings.
// TIP: if either instances or clientCIDRs or portMappings are nil, then the securityGroup rules for lbName are cleared.
func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[InstanceID]*ec2.Instance, clientCIDRs []string, portMappings []nlbPortMapping) error {
if c.cfg.Global.DisableSecurityGroupIngress {
return nil
}
// Compare to actual groups
for _, actualGroup := range actualGroups {
actualGroupID := aws.StringValue(actualGroup.GroupId)
if actualGroupID == "" {
klog.Warning("Ignoring group without ID: ", actualGroup)
clusterSGs, err := c.getTaggedSecurityGroups()
if err != nil {
return fmt.Errorf("error querying for tagged security groups: %q", err)
}
// scan instances for groups we want to open
desiredSGIDs := sets.String{}
for _, instance := range instances {
sg, err := findSecurityGroupForInstance(instance, clusterSGs)
if err != nil {
return err
}
if sg == nil {
klog.Warningf("Ignoring instance without security group: %s", aws.StringValue(instance.InstanceId))
continue
}
desiredSGIDs.Insert(aws.StringValue(sg.GroupId))
}
addingMap, ok := portChanges[actualGroupID]
if ok {
desiredSet := sets.NewInt64()
for port := range addingMap {
desiredSet.Insert(port)
}
existingSet := portsForNLB(lbName, actualGroup, clientTraffic)
// remove from portChanges ports that are already allowed
if intersection := desiredSet.Intersection(existingSet); intersection.Len() > 0 {
for p := range intersection {
delete(portChanges[actualGroupID], p)
}
}
// allowed ports that need to be removed
if difference := existingSet.Difference(desiredSet); difference.Len() > 0 {
for p := range difference {
portChanges[actualGroupID][p] = false
}
// TODO(@M00nF1sh): do we really needs to support SG without cluster tag at current version?
// findSecurityGroupForInstance might return SG that are not tagged.
{
for sgID := range desiredSGIDs.Difference(sets.StringKeySet(clusterSGs)) {
sg, err := c.findSecurityGroup(sgID)
if err != nil {
return fmt.Errorf("error finding instance group: %q", err)
}
clusterSGs[sgID] = sg
}
}
// Make changes we've planned on
for instanceSecurityGroupID, portMap := range portChanges {
adds := []*ec2.IpPermission{}
removes := []*ec2.IpPermission{}
for port, add := range portMap {
if add {
if clientTraffic {
klog.V(2).Infof("Adding rule for client MTU discovery from the network load balancer (%s) to instances (%s)", clientCidrs, instanceSecurityGroupID)
klog.V(2).Infof("Adding rule for client traffic from the network load balancer (%s) to instances (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port)
} else {
klog.V(2).Infof("Adding rule for health check traffic from the network load balancer (%s) to instances (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port)
}
} else {
if clientTraffic {
klog.V(2).Infof("Removing rule for client MTU discovery from the network load balancer (%s) to instances (%s)", clientCidrs, instanceSecurityGroupID)
klog.V(2).Infof("Removing rule for client traffic from the network load balancer (%s) to instance (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port)
}
klog.V(2).Infof("Removing rule for health check traffic from the network load balancer (%s) to instance (%s), port (%v)", clientCidrs, instanceSecurityGroupID, port)
}
if clientTraffic {
clientRuleAnnotation := fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName)
// Client Traffic
permission := &ec2.IpPermission{
FromPort: aws.Int64(port),
ToPort: aws.Int64(port),
IpProtocol: aws.String("tcp"),
}
ranges := []*ec2.IpRange{}
for _, cidr := range clientCidrs {
ranges = append(ranges, &ec2.IpRange{
CidrIp: aws.String(cidr),
Description: aws.String(clientRuleAnnotation),
})
}
permission.IpRanges = ranges
if add {
adds = append(adds, permission)
} else {
removes = append(removes, permission)
}
} else {
healthRuleAnnotation := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName)
// NLB HealthCheck
permission := &ec2.IpPermission{
FromPort: aws.Int64(port),
ToPort: aws.Int64(port),
IpProtocol: aws.String("tcp"),
}
ranges := []*ec2.IpRange{}
for _, cidr := range clientCidrs {
ranges = append(ranges, &ec2.IpRange{
CidrIp: aws.String(cidr),
Description: aws.String(healthRuleAnnotation),
})
}
permission.IpRanges = ranges
if add {
adds = append(adds, permission)
} else {
removes = append(removes, permission)
}
}
{
clientPorts := sets.Int64{}
healthCheckPorts := sets.Int64{}
for _, port := range portMappings {
clientPorts.Insert(port.TrafficPort)
healthCheckPorts.Insert(port.HealthCheckPort)
}
if len(adds) > 0 {
changed, err := c.addSecurityGroupIngress(instanceSecurityGroupID, adds)
if err != nil {
return err
}
if !changed {
klog.Warning("Allowing ingress was not needed; concurrent change? groupId=", instanceSecurityGroupID)
}
clientRuleAnnotation := fmt.Sprintf("%s=%s", NLBClientRuleDescription, lbName)
healthRuleAnnotation := fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, lbName)
vpcCIDRs, err := c.getVpcCidrBlocks()
if err != nil {
return err
}
if len(removes) > 0 {
changed, err := c.removeSecurityGroupIngress(instanceSecurityGroupID, removes)
if err != nil {
return err
}
if !changed {
klog.Warning("Revoking ingress was not needed; concurrent change? groupId=", instanceSecurityGroupID)
}
}
if clientTraffic {
// MTU discovery
mtuRuleAnnotation := fmt.Sprintf("%s=%s", NLBMtuDiscoveryRuleDescription, lbName)
mtuPermission := &ec2.IpPermission{
IpProtocol: aws.String("icmp"),
FromPort: aws.Int64(3),
ToPort: aws.Int64(4),
}
ranges := []*ec2.IpRange{}
for _, cidr := range clientCidrs {
ranges = append(ranges, &ec2.IpRange{
CidrIp: aws.String(cidr),
Description: aws.String(mtuRuleAnnotation),
})
}
mtuPermission.IpRanges = ranges
group, err := c.findSecurityGroup(instanceSecurityGroupID)
if err != nil {
klog.Warningf("Error retrieving security group: %q", err)
return err
}
if group == nil {
klog.Warning("Security group not found: ", instanceSecurityGroupID)
return nil
}
icmpExists := false
permCount := 0
for _, perm := range group.IpPermissions {
if *perm.IpProtocol == "icmp" {
icmpExists = true
continue
}
if perm.FromPort != nil {
permCount++
}
}
if !icmpExists && permCount > 0 {
// the icmp permission is missing
changed, err := c.addSecurityGroupIngress(instanceSecurityGroupID, []*ec2.IpPermission{mtuPermission})
if err != nil {
klog.Warningf("Error adding MTU permission to security group: %q", err)
for sgID, sg := range clusterSGs {
sgPerms := NewIPPermissionSet(sg.IpPermissions...).Ungroup()
if desiredSGIDs.Has(sgID) {
if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", healthCheckPorts, vpcCIDRs); err != nil {
return err
}
if !changed {
klog.Warning("Allowing ingress was not needed; concurrent change? groupId=", instanceSecurityGroupID)
}
} else if icmpExists && permCount == 0 {
// there is no additional permissions, remove icmp
changed, err := c.removeSecurityGroupIngress(instanceSecurityGroupID, []*ec2.IpPermission{mtuPermission})
if err != nil {
klog.Warningf("Error removing MTU permission to security group: %q", err)
if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, "tcp", clientPorts, clientCIDRs); err != nil {
return err
}
if !changed {
klog.Warning("Revoking ingress was not needed; concurrent change? groupId=", instanceSecurityGroupID)
} else {
if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, healthRuleAnnotation, "tcp", nil, nil); err != nil {
return err
}
if err := c.updateInstanceSecurityGroupForNLBTraffic(sgID, sgPerms, clientRuleAnnotation, "tcp", nil, nil); err != nil {
return err
}
}
if !sgPerms.Equal(NewIPPermissionSet(sg.IpPermissions...).Ungroup()) {
if err := c.updateInstanceSecurityGroupForNLBMTU(sgID, sgPerms); err != nil {
return err
}
}
}
@ -914,102 +742,105 @@ func (c *Cloud) updateInstanceSecurityGroupsForNLBTraffic(actualGroups []*ec2.Se
return nil
}
// Add SG rules for a given NLB
func (c *Cloud) updateInstanceSecurityGroupsForNLB(mappings []nlbPortMapping, instances map[InstanceID]*ec2.Instance, lbName string, clientCidrs []string) error {
if c.cfg.Global.DisableSecurityGroupIngress {
return nil
}
vpcCidrBlocks, err := c.getVpcCidrBlocks()
if err != nil {
return err
}
// Unlike the classic ELB, NLB does not have a security group that we can
// filter against all existing groups to see if they allow access. Instead
// we use the IpRange.Description field to annotate NLB health check and
// client traffic rules
// Get the actual list of groups that allow ingress for the load-balancer
var actualGroups []*ec2.SecurityGroup
{
// Server side filter
describeRequest := &ec2.DescribeSecurityGroupsInput{}
describeRequest.Filters = []*ec2.Filter{
newEc2Filter("ip-permission.protocol", "tcp"),
newEc2Filter("vpc-id", c.vpcID),
// updateInstanceSecurityGroupForNLBTraffic will manage permissions set(identified by ruleDesc) on securityGroup to match desired set(allow protocol traffic from ports/cidr).
// Note: sgPerms will be updated to reflect the current permission set on SG after update.
func (c *Cloud) updateInstanceSecurityGroupForNLBTraffic(sgID string, sgPerms IPPermissionSet, ruleDesc string, protocol string, ports sets.Int64, cidrs []string) error {
desiredPerms := NewIPPermissionSet()
for port := range ports {
for _, cidr := range cidrs {
desiredPerms.Insert(&ec2.IpPermission{
IpProtocol: aws.String(protocol),
FromPort: aws.Int64(port),
ToPort: aws.Int64(port),
IpRanges: []*ec2.IpRange{
{
CidrIp: aws.String(cidr),
Description: aws.String(ruleDesc),
},
},
})
}
response, err := c.ec2.DescribeSecurityGroups(describeRequest)
if err != nil {
return fmt.Errorf("Error querying security groups for NLB: %q", err)
}
for _, sg := range response {
if !c.tagging.hasClusterTag(sg.Tags) {
continue
}
actualGroups = append(actualGroups, sg)
}
// client-side filter
// Filter out groups that don't have IP Rules we've annotated for this service
actualGroups = filterForIPRangeDescription(actualGroups, lbName)
}
taggedSecurityGroups, err := c.getTaggedSecurityGroups()
if err != nil {
return fmt.Errorf("Error querying for tagged security groups: %q", err)
}
externalTrafficPolicyIsLocal := false
trafficPorts := []int64{}
for i := range mappings {
trafficPorts = append(trafficPorts, mappings[i].TrafficPort)
if mappings[i].TrafficPort != mappings[i].HealthCheckPort {
externalTrafficPolicyIsLocal = true
}
}
healthCheckPorts := trafficPorts
// if externalTrafficPolicy is Local, all listeners use the same health
// check port
if externalTrafficPolicyIsLocal && len(mappings) > 0 {
healthCheckPorts = []int64{mappings[0].HealthCheckPort}
}
desiredGroupIds := []string{}
// Scan instances for groups we want open
for _, instance := range instances {
securityGroup, err := findSecurityGroupForInstance(instance, taggedSecurityGroups)
}
permsToGrant := desiredPerms.Difference(sgPerms)
permsToRevoke := sgPerms.Difference(desiredPerms)
permsToRevoke.DeleteIf(IPPermissionNotMatch{IPPermissionMatchDesc{ruleDesc}})
if len(permsToRevoke) > 0 {
permsToRevokeList := permsToRevoke.List()
changed, err := c.removeSecurityGroupIngress(sgID, permsToRevokeList)
if err != nil {
klog.Warningf("Error remove traffic permission from security group: %q", err)
return err
}
if !changed {
klog.Warning("Revoking ingress was not needed; concurrent change? groupId=", sgID)
}
sgPerms.Delete(permsToRevokeList...)
}
if len(permsToGrant) > 0 {
permsToGrantList := permsToGrant.List()
changed, err := c.addSecurityGroupIngress(sgID, permsToGrantList)
if err != nil {
klog.Warningf("Error add traffic permission to security group: %q", err)
return err
}
if !changed {
klog.Warning("Allowing ingress was not needed; concurrent change? groupId=", sgID)
}
sgPerms.Insert(permsToGrantList...)
}
return nil
}
if securityGroup == nil {
klog.Warningf("Ignoring instance without security group: %s", aws.StringValue(instance.InstanceId))
continue
// Note: sgPerms will be updated to reflect the current permission set on SG after update.
func (c *Cloud) updateInstanceSecurityGroupForNLBMTU(sgID string, sgPerms IPPermissionSet) error {
desiredPerms := NewIPPermissionSet()
for _, perm := range sgPerms {
for _, ipRange := range perm.IpRanges {
if strings.Contains(aws.StringValue(ipRange.Description), NLBClientRuleDescription) {
desiredPerms.Insert(&ec2.IpPermission{
IpProtocol: aws.String("icmp"),
FromPort: aws.Int64(3),
ToPort: aws.Int64(4),
IpRanges: []*ec2.IpRange{
{
CidrIp: ipRange.CidrIp,
Description: aws.String(NLBMtuDiscoveryRuleDescription),
},
},
})
}
}
}
permsToGrant := desiredPerms.Difference(sgPerms)
permsToRevoke := sgPerms.Difference(desiredPerms)
permsToRevoke.DeleteIf(IPPermissionNotMatch{IPPermissionMatchDesc{NLBMtuDiscoveryRuleDescription}})
if len(permsToRevoke) > 0 {
permsToRevokeList := permsToRevoke.List()
changed, err := c.removeSecurityGroupIngress(sgID, permsToRevokeList)
if err != nil {
klog.Warningf("Error remove MTU permission from security group: %q", err)
return err
}
if !changed {
klog.Warning("Revoking ingress was not needed; concurrent change? groupId=", sgID)
}
id := aws.StringValue(securityGroup.GroupId)
if id == "" {
klog.Warningf("found security group without id: %v", securityGroup)
continue
sgPerms.Delete(permsToRevokeList...)
}
if len(permsToGrant) > 0 {
permsToGrantList := permsToGrant.List()
changed, err := c.addSecurityGroupIngress(sgID, permsToGrantList)
if err != nil {
klog.Warningf("Error add MTU permission to security group: %q", err)
return err
}
desiredGroupIds = append(desiredGroupIds, id)
if !changed {
klog.Warning("Allowing ingress was not needed; concurrent change? groupId=", sgID)
}
sgPerms.Insert(permsToGrantList...)
}
// Run once for Client traffic
err = c.updateInstanceSecurityGroupsForNLBTraffic(actualGroups, desiredGroupIds, trafficPorts, lbName, clientCidrs, true)
if err != nil {
return err
}
// Run once for health check traffic
err = c.updateInstanceSecurityGroupsForNLBTraffic(actualGroups, desiredGroupIds, healthCheckPorts, lbName, vpcCidrBlocks, false)
if err != nil {
return err
}
return nil
}

View File

@ -17,11 +17,9 @@ limitations under the License.
package aws
import (
"fmt"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/elb"
"github.com/stretchr/testify/assert"
)
@ -165,66 +163,6 @@ func TestIsNLB(t *testing.T) {
}
}
func TestSecurityGroupFiltering(t *testing.T) {
grid := []struct {
in []*ec2.SecurityGroup
name string
expected int
description string
}{
{
in: []*ec2.SecurityGroup{
{
IpPermissions: []*ec2.IpPermission{
{
IpRanges: []*ec2.IpRange{
{
Description: aws.String("an unmanaged"),
},
},
},
},
},
},
name: "unmanaged",
expected: 0,
description: "An environment without managed LBs should have %d, but found %d SecurityGroups",
},
{
in: []*ec2.SecurityGroup{
{
IpPermissions: []*ec2.IpPermission{
{
IpRanges: []*ec2.IpRange{
{
Description: aws.String("an unmanaged"),
},
{
Description: aws.String(fmt.Sprintf("%s=%s", NLBClientRuleDescription, "managedlb")),
},
{
Description: aws.String(fmt.Sprintf("%s=%s", NLBHealthCheckRuleDescription, "managedlb")),
},
},
},
},
},
},
name: "managedlb",
expected: 1,
description: "Found %d, but should have %d Security Groups",
},
}
for _, g := range grid {
actual := len(filterForIPRangeDescription(g.in, g.name))
if actual != g.expected {
t.Errorf(g.description, actual, g.expected)
}
}
}
func TestSyncElbListeners(t *testing.T) {
tests := []struct {
name string

View File

@ -20,12 +20,19 @@ import (
"encoding/json"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
)
// IPPermissionSet maps IP strings of strings to EC2 IpPermissions
type IPPermissionSet map[string]*ec2.IpPermission
// IPPermissionPredicate is an predicate to test whether IPPermission matches some condition.
type IPPermissionPredicate interface {
// Test checks whether specified IPPermission matches condition.
Test(perm *ec2.IpPermission) bool
}
// NewIPPermissionSet creates a new IPPermissionSet
func NewIPPermissionSet(items ...*ec2.IpPermission) IPPermissionSet {
s := make(IPPermissionSet)
@ -90,6 +97,23 @@ func (s IPPermissionSet) Insert(items ...*ec2.IpPermission) {
}
}
// Delete delete permission from the set.
func (s IPPermissionSet) Delete(items ...*ec2.IpPermission) {
for _, p := range items {
k := keyForIPPermission(p)
delete(s, k)
}
}
// DeleteIf delete permission from the set if permission matches predicate.
func (s IPPermissionSet) DeleteIf(predicate IPPermissionPredicate) {
for k, p := range s {
if predicate.Test(p) {
delete(s, k)
}
}
}
// List returns the contents as a slice. Order is not defined.
func (s IPPermissionSet) List() []*ec2.IpPermission {
res := make([]*ec2.IpPermission, 0, len(s))
@ -146,3 +170,47 @@ func keyForIPPermission(p *ec2.IpPermission) string {
}
return string(v)
}
var _ IPPermissionPredicate = IPPermissionMatchDesc{}
// IPPermissionMatchDesc checks whether specific IPPermission contains description.
type IPPermissionMatchDesc struct {
Description string
}
// Test whether specific IPPermission contains description.
func (p IPPermissionMatchDesc) Test(perm *ec2.IpPermission) bool {
for _, v4Range := range perm.IpRanges {
if aws.StringValue(v4Range.Description) == p.Description {
return true
}
}
for _, v6Range := range perm.Ipv6Ranges {
if aws.StringValue(v6Range.Description) == p.Description {
return true
}
}
for _, prefixListID := range perm.PrefixListIds {
if aws.StringValue(prefixListID.Description) == p.Description {
return true
}
}
for _, group := range perm.UserIdGroupPairs {
if aws.StringValue(group.Description) == p.Description {
return true
}
}
return false
}
var _ IPPermissionPredicate = IPPermissionNotMatch{}
// IPPermissionNotMatch is the *not* operator for Predicate
type IPPermissionNotMatch struct {
Predicate IPPermissionPredicate
}
// Test whether specific IPPermission not match the embed predicate.
func (p IPPermissionNotMatch) Test(perm *ec2.IpPermission) bool {
return !p.Predicate.Test(perm)
}