diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 3d3d2aecd9b..b3877f052bc 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -141,9 +141,14 @@ const ServiceAnnotationLoadBalancerConnectionIdleTimeout = "service.beta.kuberne const ServiceAnnotationLoadBalancerCrossZoneLoadBalancingEnabled = "service.beta.kubernetes.io/aws-load-balancer-cross-zone-load-balancing-enabled" // ServiceAnnotationLoadBalancerExtraSecurityGroups is the annotation used -// one the service to specify additional security groups to be added to ELB created +// on the service to specify additional security groups to be added to ELB created const ServiceAnnotationLoadBalancerExtraSecurityGroups = "service.beta.kubernetes.io/aws-load-balancer-extra-security-groups" +// ServiceAnnotationLoadBalancerSecurityGroups is the annotation used +// on the service to specify the security groups to be added to ELB created. Differently from the annotation +// "service.beta.kubernetes.io/aws-load-balancer-extra-security-groups", this replaces all other security groups previously assigned to the ELB. +const ServiceAnnotationLoadBalancerSecurityGroups = "service.beta.kubernetes.io/aws-load-balancer-security-groups" + // ServiceAnnotationLoadBalancerCertificate is the annotation used on the // service to request a secure listener. Value is a valid certificate ARN. // For more, see http://docs.aws.amazon.com/ElasticLoadBalancing/latest/DeveloperGuide/elb-listener-config.html @@ -3183,7 +3188,8 @@ func getPortSets(annotation string) (ports *portSets) { // attached to ELB created by a service. List always consist of at least // 1 member which is an SG created for this service or a SG from the Global config. // Extra groups can be specified via annotation, as can extra tags for any -// new groups. +// new groups. The annotation "ServiceAnnotationLoadBalancerSecurityGroups" allows for +// setting the security groups specified. func (c *Cloud) buildELBSecurityGroupList(serviceName types.NamespacedName, loadBalancerName string, annotations map[string]string) ([]string, error) { var err error var securityGroupID string @@ -3200,7 +3206,20 @@ func (c *Cloud) buildELBSecurityGroupList(serviceName types.NamespacedName, load return nil, err } } - sgList := []string{securityGroupID} + + sgList := []string{} + + for _, extraSG := range strings.Split(annotations[ServiceAnnotationLoadBalancerSecurityGroups], ",") { + extraSG = strings.TrimSpace(extraSG) + if len(extraSG) > 0 { + sgList = append(sgList, extraSG) + } + } + + // If no Security Groups have been specified with the ServiceAnnotationLoadBalancerSecurityGroups annotation, we add the default one. + if len(sgList) == 0 { + sgList = append(sgList, securityGroupID) + } for _, extraSG := range strings.Split(annotations[ServiceAnnotationLoadBalancerExtraSecurityGroups], ",") { extraSG = strings.TrimSpace(extraSG) diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index 868e6fcd7de..9ff80e4c35c 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -1198,6 +1198,38 @@ func TestLBExtraSecurityGroupsAnnotation(t *testing.T) { } } +func TestLBSecurityGroupsAnnotation(t *testing.T) { + awsServices := newMockedFakeAWSServices(TestClusterId) + c, _ := newAWSCloud(CloudConfig{}, awsServices) + + sg1 := map[string]string{ServiceAnnotationLoadBalancerSecurityGroups: "sg-000001"} + sg2 := map[string]string{ServiceAnnotationLoadBalancerSecurityGroups: "sg-000002"} + sg3 := map[string]string{ServiceAnnotationLoadBalancerSecurityGroups: "sg-000001, sg-000002"} + + tests := []struct { + name string + + annotations map[string]string + expectedSGs []string + }{ + {"SG specified", sg1, []string{sg1[ServiceAnnotationLoadBalancerSecurityGroups]}}, + {"Multiple SGs specified", sg3, []string{sg1[ServiceAnnotationLoadBalancerSecurityGroups], sg2[ServiceAnnotationLoadBalancerSecurityGroups]}}, + } + + awsServices.ec2.(*MockedFakeEC2).expectDescribeSecurityGroups(TestClusterId, "k8s-elb-aid", "cluster.test") + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + serviceName := types.NamespacedName{Namespace: "default", Name: "myservice"} + + sgList, err := c.buildELBSecurityGroupList(serviceName, "aid", test.annotations) + assert.NoError(t, err, "buildELBSecurityGroupList failed") + assert.True(t, sets.NewString(test.expectedSGs...).Equal(sets.NewString(sgList...)), + "Security Groups expected=%q , returned=%q", test.expectedSGs, sgList) + }) + } +} + // Test that we can add a load balancer tag func TestAddLoadBalancerTags(t *testing.T) { loadBalancerName := "test-elb"