diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index e8ea3a46e3d..0758745f13d 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -1596,10 +1596,80 @@ func isEqualUserGroupPair(l, r *ec2.UserIdGroupPair, compareGroupUserIDs bool) b return false } +// Makes sure the security group ingress is exactly the specified permissions +// Returns true if and only if changes were made +// The security group must already exist +func (s *AWSCloud) setSecurityGroupIngress(securityGroupId string, permissions IPPermissionSet) (bool, error) { + group, err := s.findSecurityGroup(securityGroupId) + if err != nil { + glog.Warning("Error retrieving security group", err) + return false, err + } + + if group == nil { + return false, fmt.Errorf("security group not found: %s", securityGroupId) + } + + glog.V(2).Infof("Existing security group ingress: %s %v", securityGroupId, group.IpPermissions) + + actual := NewIPPermissionSet(group.IpPermissions...) + + // EC2 groups rules together, for example combining: + // + // { Port=80, Range=[A] } and { Port=80, Range=[B] } + // + // into { Port=80, Range=[A,B] } + // + // We have to ungroup them, because otherwise the logic becomes really + // complicated, and also because if we have Range=[A,B] and we try to + // add Range=[A] then EC2 complains about a duplicate rule. + permissions = permissions.Ungroup() + actual = actual.Ungroup() + + remove := actual.Difference(permissions) + add := permissions.Difference(actual) + + if add.Len() == 0 && remove.Len() == 0 { + return false, nil + } + + // TODO: There is a limit in VPC of 100 rules per security group, so we + // probably should try grouping or combining to fit under this limit. + // But this is only used on the ELB security group currently, so it + // would require (ports * CIDRS) > 100. Also, it isn't obvious exactly + // how removing single permissions from compound rules works, and we + // don't want to accidentally open more than intended while we're + // applying changes. + if add.Len() != 0 { + glog.V(2).Infof("Adding security group ingress: %s %v", securityGroupId, add.List()) + + request := &ec2.AuthorizeSecurityGroupIngressInput{} + request.GroupId = &securityGroupId + request.IpPermissions = add.List() + _, err = s.ec2.AuthorizeSecurityGroupIngress(request) + if err != nil { + return false, fmt.Errorf("error authorizing security group ingress: %v", err) + } + } + if remove.Len() != 0 { + glog.V(2).Infof("Remove security group ingress: %s %v", securityGroupId, remove.List()) + + request := &ec2.RevokeSecurityGroupIngressInput{} + request.GroupId = &securityGroupId + request.IpPermissions = remove.List() + _, err = s.ec2.RevokeSecurityGroupIngress(request) + if err != nil { + return false, fmt.Errorf("error revoking security group ingress: %v", err) + } + } + + return true, nil +} + // Makes sure the security group includes the specified permissions // Returns true if and only if changes were made // The security group must already exist -func (s *AWSCloud) ensureSecurityGroupIngress(securityGroupId string, addPermissions []*ec2.IpPermission) (bool, error) { +func (s *AWSCloud) addSecurityGroupIngress(securityGroupId string, addPermissions []*ec2.IpPermission) (bool, error) { group, err := s.findSecurityGroup(securityGroupId) if err != nil { glog.Warning("Error retrieving security group", err) @@ -2020,7 +2090,7 @@ func (s *AWSCloud) EnsureLoadBalancer(name, region string, publicIP net.IP, port ec2SourceRanges = append(ec2SourceRanges, &ec2.IpRange{CidrIp: aws.String(sourceRange)}) } - permissions := []*ec2.IpPermission{} + permissions := NewIPPermissionSet() for _, port := range ports { portInt64 := int64(port.Port) protocol := strings.ToLower(string(port.Protocol)) @@ -2031,9 +2101,9 @@ func (s *AWSCloud) EnsureLoadBalancer(name, region string, publicIP net.IP, port permission.IpRanges = ec2SourceRanges permission.IpProtocol = &protocol - permissions = append(permissions, permission) + permissions.Insert(permission) } - _, err = s.ensureSecurityGroupIngress(securityGroupID, permissions) + _, err = s.setSecurityGroupIngress(securityGroupID, permissions) if err != nil { return nil, err } @@ -2285,7 +2355,7 @@ func (s *AWSCloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalan permissions := []*ec2.IpPermission{permission} if add { - changed, err := s.ensureSecurityGroupIngress(instanceSecurityGroupId, permissions) + changed, err := s.addSecurityGroupIngress(instanceSecurityGroupId, permissions) if err != nil { return err } diff --git a/pkg/cloudprovider/providers/aws/sets_ippermissions.go b/pkg/cloudprovider/providers/aws/sets_ippermissions.go new file mode 100644 index 00000000000..2e1343ff8d9 --- /dev/null +++ b/pkg/cloudprovider/providers/aws/sets_ippermissions.go @@ -0,0 +1,146 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "encoding/json" + "fmt" + + "github.com/aws/aws-sdk-go/service/ec2" +) + +type IPPermissionSet map[string]*ec2.IpPermission + +func NewIPPermissionSet(items ...*ec2.IpPermission) IPPermissionSet { + s := make(IPPermissionSet) + s.Insert(items...) + return s +} + +// Ungroup splits permissions out into individual permissions +// EC2 will combine permissions with the same port but different SourceRanges together, for example +// We ungroup them so we can process them +func (s IPPermissionSet) Ungroup() IPPermissionSet { + l := []*ec2.IpPermission{} + for _, p := range s.List() { + if len(p.IpRanges) <= 1 { + l = append(l, p) + continue + } + for _, ipRange := range p.IpRanges { + c := &ec2.IpPermission{} + *c = *p + c.IpRanges = []*ec2.IpRange{ipRange} + l = append(l, c) + } + } + + l2 := []*ec2.IpPermission{} + for _, p := range l { + if len(p.UserIdGroupPairs) <= 1 { + l2 = append(l2, p) + continue + } + for _, u := range p.UserIdGroupPairs { + c := &ec2.IpPermission{} + *c = *p + c.UserIdGroupPairs = []*ec2.UserIdGroupPair{u} + l2 = append(l, c) + } + } + + l3 := []*ec2.IpPermission{} + for _, p := range l2 { + if len(p.PrefixListIds) <= 1 { + l3 = append(l3, p) + continue + } + for _, v := range p.PrefixListIds { + c := &ec2.IpPermission{} + *c = *p + c.PrefixListIds = []*ec2.PrefixListId{v} + l3 = append(l3, c) + } + } + + return NewIPPermissionSet(l3...) +} + +// Insert adds items to the set. +func (s IPPermissionSet) Insert(items ...*ec2.IpPermission) { + for _, p := range items { + k := keyForIPPermission(p) + s[k] = p + } +} + +// List returns the contents as a slice. Order is not defined. +func (s IPPermissionSet) List() []*ec2.IpPermission { + res := make([]*ec2.IpPermission, 0, len(s)) + for _, v := range s { + res = append(res, v) + } + return res +} + +// IsSuperset returns true if and only if s1 is a superset of s2. +func (s1 IPPermissionSet) IsSuperset(s2 IPPermissionSet) bool { + for k := range s2 { + _, found := s1[k] + if !found { + return false + } + } + return true +} + +// Equal returns true if and only if s1 is equal (as a set) to s2. +// Two sets are equal if their membership is identical. +// (In practice, this means same elements, order doesn't matter) +func (s1 IPPermissionSet) Equal(s2 IPPermissionSet) bool { + return len(s1) == len(s2) && s1.IsSuperset(s2) +} + +// Difference returns a set of objects that are not in s2 +// For example: +// s1 = {a1, a2, a3} +// s2 = {a1, a2, a4, a5} +// s1.Difference(s2) = {a3} +// s2.Difference(s1) = {a4, a5} +func (s IPPermissionSet) Difference(s2 IPPermissionSet) IPPermissionSet { + result := NewIPPermissionSet() + for k, v := range s { + _, found := s2[k] + if !found { + result[k] = v + } + } + return result +} + +// Len returns the size of the set. +func (s IPPermissionSet) Len() int { + return len(s) +} + +func keyForIPPermission(p *ec2.IpPermission) string { + v, err := json.Marshal(p) + if err != nil { + panic(fmt.Sprintf("error building JSON representation of ec2.IpPermission: %v", err)) + } + return string(v) +}