diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 72c515abe9d..b21b8bce787 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -41,6 +41,7 @@ import ( "k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/cloudprovider" + "k8s.io/kubernetes/pkg/util/sets" "github.com/golang/glog" ) @@ -1573,6 +1574,40 @@ func (s *AWSCloud) createTags(request *ec2.CreateTagsInput) (*ec2.CreateTagsOutp } } +func (s *AWSCloud) listSubnetIDsinVPC(vpc *ec2.Vpc) ([]string, error) { + + subnetIds := []string{} + + request := &ec2.DescribeSubnetsInput{} + filters := []*ec2.Filter{} + filters = append(filters, newEc2Filter("vpc-id", orEmpty(vpc.VpcId))) + // Note, this will only return subnets tagged with the cluster identifier for this Kubernetes cluster. + // In the case where an AZ has public & private subnets per AWS best practices, the deployment should ensure + // only the public subnet (where the ELB will go) is so tagged. + filters = s.addFilters(filters) + request.Filters = filters + + subnets, err := s.ec2.DescribeSubnets(request) + if err != nil { + glog.Error("error describing subnets: ", err) + return nil, err + } + + availabilityZones := sets.NewString() + for _, subnet := range subnets { + az := orEmpty(subnet.AvailabilityZone) + id := orEmpty(subnet.SubnetId) + if availabilityZones.Has(az) { + glog.Warning("Found multiple subnets per AZ '", az, "', ignoring subnet '", id, "'") + continue + } + subnetIds = append(subnetIds, id) + availabilityZones.Insert(az) + } + + return subnetIds, nil +} + // EnsureTCPLoadBalancer implements TCPLoadBalancer.EnsureTCPLoadBalancer // TODO(justinsb) It is weird that these take a region. I suspect it won't work cross-region anwyay. func (s *AWSCloud) EnsureTCPLoadBalancer(name, region string, publicIP net.IP, ports []*api.ServicePort, hosts []string, affinity api.ServiceAffinity) (*api.LoadBalancerStatus, error) { @@ -1606,32 +1641,10 @@ func (s *AWSCloud) EnsureTCPLoadBalancer(name, region string, publicIP net.IP, p } // Construct list of configured subnets - subnetIDs := []string{} - { - request := &ec2.DescribeSubnetsInput{} - filters := []*ec2.Filter{} - filters = append(filters, newEc2Filter("vpc-id", orEmpty(vpc.VpcId))) - // Note, this will only return subnets tagged with the cluster identifier for this Kubernetes cluster. - // In the case where an AZ has public & private subnets per AWS best practices, the deployment should ensure - // only the public subnet (where the ELB will go) is so tagged. - filters = s.addFilters(filters) - request.Filters = filters - - subnets, err := s.ec2.DescribeSubnets(request) - if err != nil { - glog.Error("Error describing subnets: ", err) - return nil, err - } - - // zones := []string{} - for _, subnet := range subnets { - subnetIDs = append(subnetIDs, orEmpty(subnet.SubnetId)) - if !strings.HasPrefix(orEmpty(subnet.AvailabilityZone), region) { - glog.Error("Found AZ that did not match region", orEmpty(subnet.AvailabilityZone), " vs ", region) - return nil, fmt.Errorf("invalid AZ for region") - } - // zones = append(zones, subnet.AvailabilityZone) - } + subnetIDs, err := s.listSubnetIDsinVPC(vpc) + if err != nil { + glog.Error("error listing subnets in VPC", err) + return nil, err } // Create a security group for the load balancer diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index bb656bf4feb..51e97cd54a6 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -247,7 +247,9 @@ func TestNewAWSCloud(t *testing.T) { } type FakeEC2 struct { - aws *FakeAWSServices + aws *FakeAWSServices + Subnets []*ec2.Subnet + DescribeSubnetsInput *ec2.DescribeSubnetsInput } func contains(haystack []*string, needle string) bool { @@ -385,8 +387,9 @@ func (ec2 *FakeEC2) DescribeVPCs(*ec2.DescribeVpcsInput) ([]*ec2.Vpc, error) { panic("Not implemented") } -func (ec2 *FakeEC2) DescribeSubnets(*ec2.DescribeSubnetsInput) ([]*ec2.Subnet, error) { - panic("Not implemented") +func (ec2 *FakeEC2) DescribeSubnets(request *ec2.DescribeSubnetsInput) ([]*ec2.Subnet, error) { + ec2.DescribeSubnetsInput = request + return ec2.Subnets, nil } func (ec2 *FakeEC2) CreateTags(*ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) { @@ -697,3 +700,92 @@ func TestLoadBalancerMatchesClusterRegion(t *testing.T) { t.Errorf("Expected UpdateTCPLoadBalancer region mismatch error.") } } + +func constructSubnets(subnetsIn map[int]map[string]string) (subnetsOut []*ec2.Subnet) { + for i := range subnetsIn { + subnetsOut = append( + subnetsOut, + constructSubnet( + subnetsIn[i]["id"], + subnetsIn[i]["az"], + ), + ) + } + return +} + +func constructSubnet(id string, az string) *ec2.Subnet { + return &ec2.Subnet{ + SubnetId: &id, + AvailabilityZone: &az, + } +} + +func TestSubnetIDsinVPC(t *testing.T) { + awsServices := NewFakeAWSServices() + c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + if err != nil { + t.Errorf("Error building aws cloud: %v", err) + return + } + + vpcID := "vpc-deadbeef" + vpc := &ec2.Vpc{ + VpcId: &vpcID, + } + + // test with 3 subnets from 3 different AZs + subnets := make(map[int]map[string]string) + subnets[0] = make(map[string]string) + subnets[0]["id"] = "subnet-a0000001" + subnets[0]["az"] = "af-south-1a" + subnets[1] = make(map[string]string) + subnets[1]["id"] = "subnet-b0000001" + subnets[1]["az"] = "af-south-1b" + subnets[2] = make(map[string]string) + subnets[2]["id"] = "subnet-c0000001" + subnets[2]["az"] = "af-south-1c" + awsServices.ec2.Subnets = constructSubnets(subnets) + + result, err := c.listSubnetIDsinVPC(vpc) + if err != nil { + t.Errorf("Error listing subnets: %v", err) + return + } + + if len(result) != 3 { + t.Errorf("Expected 3 subnets but got %d", len(result)) + return + } + + result_set := make(map[string]bool) + for _, v := range result { + result_set[v] = true + } + + for i := range subnets { + if !result_set[subnets[i]["id"]] { + t.Errorf("Expected subnet%d '%s' in result: %v", i, subnets[i]["id"], result) + return + } + } + + // test with 4 subnets from 3 different AZs + // add duplicate az subnet + subnets[3] = make(map[string]string) + subnets[3]["id"] = "subnet-c0000002" + subnets[3]["az"] = "af-south-1c" + awsServices.ec2.Subnets = constructSubnets(subnets) + + result, err = c.listSubnetIDsinVPC(vpc) + if err != nil { + t.Errorf("Error listing subnets: %v", err) + return + } + + if len(result) != 3 { + t.Errorf("Expected 3 subnets but got %d", len(result)) + return + } + +}