diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 8fd1891d04a..b2f1dd1b3d9 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -111,6 +111,10 @@ const ServiceAnnotationLoadBalancerConnectionIdleTimeout = "service.beta.kuberne // used on the service to enable or disable cross-zone load balancing. const ServiceAnnotationLoadBalancerCrossZoneLoadBalancingEnabled = "service.beta.kubernetes.io/aws-load-balancer-cross-zone-load-balancing-enabled" +// ServiceAnnotationLoadBalancerExtraSecurityGroups is the annotation used +// one the service to specify additional security groups to be added to ELB created +const ServiceAnnotationLoadBalancerExtraSecurityGroups = "service.beta.kubernetes.io/aws-load-balancer-extra-security-groups" + // ServiceAnnotationLoadBalancerCertificate is the annotation used on the // service to request a secure listener. Value is a valid certificate ARN. // For more, see http://docs.aws.amazon.com/ElasticLoadBalancing/latest/DeveloperGuide/elb-listener-config.html @@ -2545,6 +2549,38 @@ func getPortSets(annotation string) (ports *portSets) { return } +// buildELBSecurityGroupList returns list of SecurityGroups which should be +// attached to ELB created by a service. List always consist of at least +// 1 member which is an SG created for this service. Extra groups can be +// specified via annotation +func (c *Cloud) buildELBSecurityGroupList(serviceName types.NamespacedName, loadBalancerName, annotation string) ([]string, error) { + var err error + var securityGroupID string + + if c.cfg.Global.ElbSecurityGroup != "" { + securityGroupID = c.cfg.Global.ElbSecurityGroup + } else { + // Create a security group for the load balancer + sgName := "k8s-elb-" + loadBalancerName + sgDescription := fmt.Sprintf("Security group for Kubernetes ELB %s (%v)", loadBalancerName, serviceName) + securityGroupID, err = c.ensureSecurityGroup(sgName, sgDescription) + if err != nil { + glog.Error("Error creating load balancer security group: ", err) + return nil, err + } + } + sgList := []string{securityGroupID} + + for _, extraSG := range strings.Split(annotation, ",") { + extraSG = strings.TrimSpace(extraSG) + if len(extraSG) > 0 { + sgList = append(sgList, extraSG) + } + } + + return sgList, nil +} + // buildListener creates a new listener from the given port, adding an SSL certificate // if indicated by the appropriate annotations. func buildListener(port v1.ServicePort, annotations map[string]string, sslPorts *portSets) (*elb.Listener, error) { @@ -2765,61 +2801,50 @@ func (c *Cloud) EnsureLoadBalancer(clusterName string, apiService *v1.Service, n loadBalancerName := cloudprovider.GetLoadBalancerName(apiService) serviceName := types.NamespacedName{Namespace: apiService.Namespace, Name: apiService.Name} + securityGroupIDs, err := c.buildELBSecurityGroupList(serviceName, loadBalancerName, annotations[ServiceAnnotationLoadBalancerExtraSecurityGroups]) + if err != nil { + return nil, err + } + if len(securityGroupIDs) == 0 { + return nil, fmt.Errorf("[BUG] ELB can't have empty list of Security Groups to be assigned, this is a Kubernetes bug, please report") + } - // Create a security group for the load balancer - var securityGroupID string { - if c.cfg.Global.ElbSecurityGroup != "" { - securityGroupID = c.cfg.Global.ElbSecurityGroup + ec2SourceRanges := []*ec2.IpRange{} + for _, sourceRange := range sourceRanges.StringSlice() { + ec2SourceRanges = append(ec2SourceRanges, &ec2.IpRange{CidrIp: aws.String(sourceRange)}) + } - } else { + permissions := NewIPPermissionSet() + for _, port := range apiService.Spec.Ports { + portInt64 := int64(port.Port) + protocol := strings.ToLower(string(port.Protocol)) - sgName := "k8s-elb-" + loadBalancerName - sgDescription := fmt.Sprintf("Security group for Kubernetes ELB %s (%v)", loadBalancerName, serviceName) - securityGroupID, err = c.ensureSecurityGroup(sgName, sgDescription) - if err != nil { - glog.Error("Error creating load balancer security group: ", err) - return nil, err + permission := &ec2.IpPermission{} + permission.FromPort = &portInt64 + permission.ToPort = &portInt64 + permission.IpRanges = ec2SourceRanges + permission.IpProtocol = &protocol + + permissions.Insert(permission) + } + + // Allow ICMP fragmentation packets, important for MTU discovery + { + permission := &ec2.IpPermission{ + IpProtocol: aws.String("icmp"), + FromPort: aws.Int64(3), + ToPort: aws.Int64(4), + IpRanges: []*ec2.IpRange{{CidrIp: aws.String("0.0.0.0/0")}}, } - ec2SourceRanges := []*ec2.IpRange{} - for _, sourceRange := range sourceRanges.StringSlice() { - ec2SourceRanges = append(ec2SourceRanges, &ec2.IpRange{CidrIp: aws.String(sourceRange)}) - } - - permissions := NewIPPermissionSet() - for _, port := range apiService.Spec.Ports { - portInt64 := int64(port.Port) - protocol := strings.ToLower(string(port.Protocol)) - - permission := &ec2.IpPermission{} - permission.FromPort = &portInt64 - permission.ToPort = &portInt64 - permission.IpRanges = ec2SourceRanges - permission.IpProtocol = &protocol - - permissions.Insert(permission) - } - - // Allow ICMP fragmentation packets, important for MTU discovery - { - permission := &ec2.IpPermission{ - IpProtocol: aws.String("icmp"), - FromPort: aws.Int64(3), - ToPort: aws.Int64(4), - IpRanges: []*ec2.IpRange{{CidrIp: aws.String("0.0.0.0/0")}}, - } - - permissions.Insert(permission) - } - - _, err = c.setSecurityGroupIngress(securityGroupID, permissions) - if err != nil { - return nil, err - } + permissions.Insert(permission) + } + _, err = c.setSecurityGroupIngress(securityGroupIDs[0], permissions) + if err != nil { + return nil, err } } - securityGroupIDs := []string{securityGroupID} // Build the load balancer itself loadBalancer, err := c.ensureLoadBalancer( diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index a50eac1fc81..05169f72e37 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "fmt" "io" "reflect" "strings" @@ -386,8 +387,9 @@ func (ec2 *FakeEC2) DeleteVolume(request *ec2.DeleteVolumeInput) (resp *ec2.Dele panic("Not implemented") } -func (ec2 *FakeEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) { - panic("Not implemented") +func (e *FakeEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) { + args := e.Called(request) + return args.Get(0).([]*ec2.SecurityGroup), nil } func (ec2 *FakeEC2) CreateSecurityGroup(*ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) { @@ -1095,6 +1097,18 @@ func (self *FakeELB) expectDescribeLoadBalancers(loadBalancerName string) { }) } +func (self *FakeEC2) expectDescribeSecurityGroups(groupName, clusterID string) { + tags := []*ec2.Tag{ + {Key: aws.String(TagNameKubernetesClusterLegacy), Value: aws.String(TestClusterId)}, + {Key: aws.String(fmt.Sprintf("%s%s", TagNameKubernetesClusterPrefix, TestClusterId)), Value: aws.String(ResourceLifecycleOwned)}, + } + + self.On("DescribeSecurityGroups", &ec2.DescribeSecurityGroupsInput{Filters: []*ec2.Filter{ + newEc2Filter("group-name", groupName), + newEc2Filter("vpc-id", ""), + }}).Return([]*ec2.SecurityGroup{{Tags: tags}}) +} + func TestDescribeLoadBalancerOnDelete(t *testing.T) { awsServices := NewFakeAWSServices() c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) @@ -1350,3 +1364,37 @@ func TestGetLoadBalancerAdditionalTags(t *testing.T) { } } } + +func TestLBExtraSecurityGroupsAnnotation(t *testing.T) { + awsServices := NewFakeAWSServices() + c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + + sg1 := "sg-000001" + sg2 := "sg-000002" + + tests := []struct { + name string + + extraSGsAnnotation string + expectedSGs []string + }{ + {"No extra SG annotation", "", []string{}}, + {"Empty extra SGs specified", ", ,,", []string{}}, + {"SG specified", sg1, []string{sg1}}, + {"Multiple SGs specified", fmt.Sprintf("%s, %s", sg1, sg2), []string{sg1, sg2}}, + } + + awsServices.ec2.expectDescribeSecurityGroups("k8s-elb-aid", "cluster.test") + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + serviceName := types.NamespacedName{Namespace: "default", Name: "myservice"} + + sgList, err := c.buildELBSecurityGroupList(serviceName, "aid", test.extraSGsAnnotation) + assert.NoError(t, err, "buildELBSecurityGroupList failed") + extraSGs := sgList[1:] + assert.True(t, sets.NewString(test.expectedSGs...).Equal(sets.NewString(extraSGs...)), + "Security Groups expected=%q , returned=%q", test.expectedSGs, extraSGs) + }) + } +}