From 62a8b3d44cbfc4eba5c6cdda6ebebab48533a8c4 Mon Sep 17 00:00:00 2001 From: Jerome Touffe-Blin Date: Sun, 29 Nov 2015 22:14:19 +1100 Subject: [PATCH] Fix #17912 - pick public subnets only on ELB creation --- pkg/cloudprovider/providers/aws/aws.go | 54 ++++++++++--- pkg/cloudprovider/providers/aws/aws_test.go | 90 +++++++++++++++++++-- 2 files changed, 128 insertions(+), 16 deletions(-) diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 811c2e107e8..dd98b82bad4 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -1782,29 +1782,38 @@ func (s *AWSCloud) createTags(request *ec2.CreateTagsInput) (*ec2.CreateTagsOutp } } -func (s *AWSCloud) listSubnetIDsinVPC(vpcId string) ([]string, error) { - +func (s *AWSCloud) listPublicSubnetIDsinVPC(vpcId string) ([]string, error) { subnetIds := []string{} - request := &ec2.DescribeSubnetsInput{} + sRequest := &ec2.DescribeSubnetsInput{} filters := []*ec2.Filter{} filters = append(filters, newEc2Filter("vpc-id", vpcId)) - // Note, this will only return subnets tagged with the cluster identifier for this Kubernetes cluster. - // In the case where an AZ has public & private subnets per AWS best practices, the deployment should ensure - // only the public subnet (where the ELB will go) is so tagged. filters = s.addFilters(filters) - request.Filters = filters + sRequest.Filters = filters - subnets, err := s.ec2.DescribeSubnets(request) + subnets, err := s.ec2.DescribeSubnets(sRequest) if err != nil { glog.Error("Error describing subnets: ", err) return nil, err } + rRequest := &ec2.DescribeRouteTablesInput{} + rRequest.Filters = filters + + rt, err := s.ec2.DescribeRouteTables(rRequest) + if err != nil { + glog.Error("error describing route tables: ", err) + return nil, err + } + availabilityZones := sets.NewString() for _, subnet := range subnets { az := orEmpty(subnet.AvailabilityZone) id := orEmpty(subnet.SubnetId) + if !isSubnetPublic(rt, id) { + glog.V(2).Infof("Ignoring private subnet %q", id) + continue + } if availabilityZones.Has(az) { glog.Warning("Found multiple subnets per AZ '", az, "', ignoring subnet '", id, "'") continue @@ -1816,6 +1825,33 @@ func (s *AWSCloud) listSubnetIDsinVPC(vpcId string) ([]string, error) { return subnetIds, nil } +func isSubnetPublic(rt []*ec2.RouteTable, subnetID string) bool { + for _, table := range rt { + var found bool + for _, assoc := range table.Associations { + if aws.StringValue(assoc.SubnetId) == subnetID { + found = true + break + } + } + if !found { + continue + } + for _, route := range table.Routes { + // There is no direct way in the AWS API to determine if a subnet is public or private. + // A public subnet is one which has an internet gateway route + // we look for the gatewayId and make sure it has the prefix of igw to differentiate + // from the default in-subnet route which is called "local" + // or other virtual gateway (starting with vgv) + // or vpc peering connections (starting with pcx). + if strings.HasPrefix(aws.StringValue(route.GatewayId), "igw") { + return true + } + } + } + return false +} + // EnsureLoadBalancer implements LoadBalancer.EnsureLoadBalancer // TODO(justinsb) It is weird that these take a region. I suspect it won't work cross-region anyway. func (s *AWSCloud) EnsureLoadBalancer(name, region string, publicIP net.IP, ports []*api.ServicePort, hosts []string, serviceName types.NamespacedName, affinity api.ServiceAffinity) (*api.LoadBalancerStatus, error) { @@ -1856,7 +1892,7 @@ func (s *AWSCloud) EnsureLoadBalancer(name, region string, publicIP net.IP, port } // Construct list of configured subnets - subnetIDs, err := s.listSubnetIDsinVPC(vpcId) + subnetIDs, err := s.listPublicSubnetIDsinVPC(vpcId) if err != nil { glog.Error("Error listing subnets in VPC", err) return nil, err diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index 3151c1246bc..8ccd2f27832 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -252,9 +252,11 @@ func TestNewAWSCloud(t *testing.T) { } type FakeEC2 struct { - aws *FakeAWSServices - Subnets []*ec2.Subnet - DescribeSubnetsInput *ec2.DescribeSubnetsInput + aws *FakeAWSServices + Subnets []*ec2.Subnet + DescribeSubnetsInput *ec2.DescribeSubnetsInput + RouteTables []*ec2.RouteTable + DescribeRouteTablesInput *ec2.DescribeRouteTablesInput } func contains(haystack []*string, needle string) bool { @@ -401,8 +403,9 @@ func (ec2 *FakeEC2) CreateTags(*ec2.CreateTagsInput) (*ec2.CreateTagsOutput, err panic("Not implemented") } -func (s *FakeEC2) DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ([]*ec2.RouteTable, error) { - panic("Not implemented") +func (ec2 *FakeEC2) DescribeRouteTables(request *ec2.DescribeRouteTablesInput) ([]*ec2.RouteTable, error) { + ec2.DescribeRouteTablesInput = request + return ec2.RouteTables, nil } func (s *FakeEC2) CreateRoute(request *ec2.CreateRouteInput) (*ec2.CreateRouteOutput, error) { @@ -751,6 +754,35 @@ func constructSubnet(id string, az string) *ec2.Subnet { } } +func constructRouteTables(routeTablesIn map[string]bool) (routeTablesOut []*ec2.RouteTable) { + for subnetID := range routeTablesIn { + routeTablesOut = append( + routeTablesOut, + constructRouteTable( + subnetID, + routeTablesIn[subnetID], + ), + ) + } + return +} + +func constructRouteTable(subnetID string, public bool) *ec2.RouteTable { + var gatewayID string + if public { + gatewayID = "igw-" + subnetID[len(subnetID)-8:8] + } else { + gatewayID = "vgw-" + subnetID[len(subnetID)-8:8] + } + return &ec2.RouteTable{ + Associations: []*ec2.RouteTableAssociation{{SubnetId: aws.String(subnetID)}}, + Routes: []*ec2.Route{{ + DestinationCidrBlock: aws.String("0.0.0.0/0"), + GatewayId: aws.String(gatewayID), + }}, + } +} + func TestSubnetIDsinVPC(t *testing.T) { awsServices := NewFakeAWSServices() c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) @@ -774,7 +806,14 @@ func TestSubnetIDsinVPC(t *testing.T) { subnets[2]["az"] = "af-south-1c" awsServices.ec2.Subnets = constructSubnets(subnets) - result, err := c.listSubnetIDsinVPC(vpcID) + routeTables := map[string]bool{ + "subnet-a0000001": true, + "subnet-b0000001": true, + "subnet-c0000001": true, + } + awsServices.ec2.RouteTables = constructRouteTables(routeTables) + + result, err := c.listPublicSubnetIDsinVPC(vpcID) if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -803,8 +842,10 @@ func TestSubnetIDsinVPC(t *testing.T) { subnets[3]["id"] = "subnet-c0000002" subnets[3]["az"] = "af-south-1c" awsServices.ec2.Subnets = constructSubnets(subnets) + routeTables["subnet-c0000002"] = true + awsServices.ec2.RouteTables = constructRouteTables(routeTables) - result, err = c.listSubnetIDsinVPC(vpcID) + result, err = c.listPublicSubnetIDsinVPC(vpcID) if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -815,6 +856,41 @@ func TestSubnetIDsinVPC(t *testing.T) { return } + // test with 6 subnets from 3 different AZs + // with 3 private subnets + subnets[4] = make(map[string]string) + subnets[4]["id"] = "subnet-d0000001" + subnets[4]["az"] = "af-south-1a" + subnets[5] = make(map[string]string) + subnets[5]["id"] = "subnet-d0000002" + subnets[5]["az"] = "af-south-1b" + + awsServices.ec2.Subnets = constructSubnets(subnets) + routeTables["subnet-a0000001"] = false + routeTables["subnet-b0000001"] = false + routeTables["subnet-c0000001"] = false + routeTables["subnet-c0000002"] = true + routeTables["subnet-d0000001"] = true + routeTables["subnet-d0000002"] = true + awsServices.ec2.RouteTables = constructRouteTables(routeTables) + result, err = c.listPublicSubnetIDsinVPC(vpcID) + if err != nil { + t.Errorf("Error listing subnets: %v", err) + return + } + + if len(result) != 3 { + t.Errorf("Expected 3 subnets but got %d", len(result)) + return + } + + expected := []*string{aws.String("subnet-c0000002"), aws.String("subnet-d0000001"), aws.String("subnet-d0000002")} + for _, s := range result { + if !contains(expected, s) { + t.Errorf("Unexpected subnet '%s' found", s) + return + } + } } func TestIpPermissionExistsHandlesMultipleGroupIds(t *testing.T) {