diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 672b6214417..0dd55e59c44 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -928,12 +928,16 @@ type awsInstance struct { } // newAWSInstance creates a new awsInstance object -func newAWSInstance(ec2 EC2, instance *ec2.Instance) *awsInstance { +func newAWSInstance(ec2Service EC2, instance *ec2.Instance) *awsInstance { + az := "" + if instance.Placement != nil { + az = aws.StringValue(instance.Placement.AvailabilityZone) + } self := &awsInstance{ - ec2: ec2, + ec2: ec2Service, awsID: aws.StringValue(instance.InstanceId), nodeName: aws.StringValue(instance.PrivateDnsName), - availabilityZone: aws.StringValue(instance.Placement.AvailabilityZone), + availabilityZone: az, instanceType: aws.StringValue(instance.InstanceType), vpcID: aws.StringValue(instance.VpcId), subnetID: aws.StringValue(instance.SubnetId), diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index 1004714c6b0..8df2b3cf82c 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -110,14 +110,11 @@ func TestReadAWSCloudConfig(t *testing.T) { } type FakeAWSServices struct { - availabilityZone string + region string instances []*ec2.Instance - instanceId string - privateDnsName string + selfInstance *ec2.Instance networkInterfacesMacs []string networkInterfacesVpcIDs []string - internalIP string - externalIP string ec2 *FakeEC2 elb *FakeELB @@ -127,7 +124,7 @@ type FakeAWSServices struct { func NewFakeAWSServices() *FakeAWSServices { s := &FakeAWSServices{} - s.availabilityZone = "us-east-1a" + s.region = "us-east-1" s.ec2 = &FakeEC2{aws: s} s.elb = &FakeELB{aws: s} s.asg = &FakeASG{aws: s} @@ -136,14 +133,16 @@ func NewFakeAWSServices() *FakeAWSServices { s.networkInterfacesMacs = []string{"aa:bb:cc:dd:ee:00", "aa:bb:cc:dd:ee:01"} s.networkInterfacesVpcIDs = []string{"vpc-mac0", "vpc-mac1"} - s.instanceId = "i-self" - s.privateDnsName = "ip-172-20-0-100.ec2.internal" - s.internalIP = "192.168.0.1" - s.externalIP = "1.2.3.4" - var selfInstance ec2.Instance - selfInstance.InstanceId = &s.instanceId - selfInstance.PrivateDnsName = &s.privateDnsName - s.instances = []*ec2.Instance{&selfInstance} + selfInstance := &ec2.Instance{} + selfInstance.InstanceId = aws.String("i-self") + selfInstance.Placement = &ec2.Placement{ + AvailabilityZone: aws.String("us-east-1a"), + } + selfInstance.PrivateDnsName = aws.String("ip-172-20-0-100.ec2.internal") + selfInstance.PrivateIpAddress = aws.String("192.168.0.1") + selfInstance.PublicIpAddress = aws.String("1.2.3.4") + s.selfInstance = selfInstance + s.instances = []*ec2.Instance{selfInstance} var tag ec2.Tag tag.Key = aws.String(TagNameKubernetesCluster) @@ -154,12 +153,10 @@ func NewFakeAWSServices() *FakeAWSServices { } func (s *FakeAWSServices) withAz(az string) *FakeAWSServices { - s.availabilityZone = az - return s -} - -func (s *FakeAWSServices) withInstances(instances []*ec2.Instance) *FakeAWSServices { - s.instances = instances + if s.selfInstance.Placement == nil { + s.selfInstance.Placement = &ec2.Placement{} + } + s.selfInstance.Placement.AvailabilityZone = aws.String(az) return s } @@ -205,7 +202,7 @@ func TestNewAWSCloud(t *testing.T) { awsServices AWSServices expectError bool - zone string + region string }{ { "No config reader", @@ -220,14 +217,13 @@ func TestNewAWSCloud(t *testing.T) { { "Config specifies valid zone", strings.NewReader("[global]\nzone = eu-west-1a"), NewFakeAWSServices(), - false, "eu-west-1a", + false, "eu-west-1", }, { "Gets zone from metadata when not in config", - strings.NewReader("[global]\n"), NewFakeAWSServices(), - false, "us-east-1a", + false, "us-east-1", }, { "No zone in config or metadata", @@ -247,9 +243,9 @@ func TestNewAWSCloud(t *testing.T) { } else { if err != nil { t.Errorf("Should succeed for case: %s, got %v", test.name, err) - } else if c.availabilityZone != test.zone { - t.Errorf("Incorrect zone value (%s vs %s) for case: %s", - c.availabilityZone, test.zone, test.name) + } else if c.region != test.region { + t.Errorf("Incorrect region value (%s vs %s) for case: %s", + c.region, test.region, test.name) } } } @@ -309,8 +305,8 @@ func (self *FakeEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]* } found := false - for _, instanceId := range request.InstanceIds { - if *instanceId == *instance.InstanceId { + for _, instanceID := range request.InstanceIds { + if *instanceID == *instance.InstanceId { found = true break } @@ -343,16 +339,21 @@ type FakeMetadata struct { func (self *FakeMetadata) GetMetadata(key string) (string, error) { networkInterfacesPrefix := "network/interfaces/macs/" + i := self.aws.selfInstance if key == "placement/availability-zone" { - return self.aws.availabilityZone, nil + az := "" + if i.Placement != nil { + az = aws.StringValue(i.Placement.AvailabilityZone) + } + return az, nil } else if key == "instance-id" { - return self.aws.instanceId, nil + return aws.StringValue(i.InstanceId), nil } else if key == "local-hostname" { - return self.aws.privateDnsName, nil + return aws.StringValue(i.PrivateDnsName), nil } else if key == "local-ipv4" { - return self.aws.internalIP, nil + return aws.StringValue(i.PrivateIpAddress), nil } else if key == "public-ipv4" { - return self.aws.externalIP, nil + return aws.StringValue(i.PublicIpAddress), nil } else if strings.HasPrefix(key, networkInterfacesPrefix) { if key == networkInterfacesPrefix { return strings.Join(self.aws.networkInterfacesMacs, "/\n") + "/\n", nil @@ -499,22 +500,24 @@ func (a *FakeASG) DescribeAutoScalingGroups(*autoscaling.DescribeAutoScalingGrou panic("Not implemented") } -func mockInstancesResp(instances []*ec2.Instance) (*AWSCloud, *FakeAWSServices) { - awsServices := NewFakeAWSServices().withInstances(instances) - return &AWSCloud{ - ec2: awsServices.ec2, - availabilityZone: awsServices.availabilityZone, - metadata: &FakeMetadata{aws: awsServices}, - }, awsServices +func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (*AWSCloud, *FakeAWSServices) { + awsServices := NewFakeAWSServices() + awsServices.instances = instances + awsServices.selfInstance = selfInstance + awsCloud, err := newAWSCloud(nil, awsServices) + if err != nil { + panic(err) + } + return awsCloud, awsServices } -func mockAvailabilityZone(region string, availabilityZone string) *AWSCloud { +func mockAvailabilityZone(availabilityZone string) *AWSCloud { awsServices := NewFakeAWSServices().withAz(availabilityZone) - return &AWSCloud{ - ec2: awsServices.ec2, - availabilityZone: awsServices.availabilityZone, - region: region, + awsCloud, err := newAWSCloud(nil, awsServices) + if err != nil { + panic(err) } + return awsCloud } func TestList(t *testing.T) { @@ -532,6 +535,7 @@ func TestList(t *testing.T) { instance0.Tags = []*ec2.Tag{&tag0} instance0.InstanceId = aws.String("instance0") instance0.PrivateDnsName = aws.String("instance0.ec2.internal") + instance0.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state0 := ec2.InstanceState{ Name: aws.String("running"), } @@ -545,6 +549,7 @@ func TestList(t *testing.T) { instance1.Tags = []*ec2.Tag{&tag1} instance1.InstanceId = aws.String("instance1") instance1.PrivateDnsName = aws.String("instance1.ec2.internal") + instance1.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state1 := ec2.InstanceState{ Name: aws.String("running"), } @@ -558,6 +563,7 @@ func TestList(t *testing.T) { instance2.Tags = []*ec2.Tag{&tag2} instance2.InstanceId = aws.String("instance2") instance2.PrivateDnsName = aws.String("instance2.ec2.internal") + instance2.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state2 := ec2.InstanceState{ Name: aws.String("running"), } @@ -571,13 +577,14 @@ func TestList(t *testing.T) { instance3.Tags = []*ec2.Tag{&tag3} instance3.InstanceId = aws.String("instance3") instance3.PrivateDnsName = aws.String("instance3.ec2.internal") + instance3.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state3 := ec2.InstanceState{ Name: aws.String("running"), } instance3.State = &state3 instances := []*ec2.Instance{&instance0, &instance1, &instance2, &instance3} - aws, _ := mockInstancesResp(instances) + aws, _ := mockInstancesResp(&instance0, instances) table := []struct { input string @@ -616,32 +623,35 @@ func TestNodeAddresses(t *testing.T) { var instance2 ec2.Instance //0 - instance0.InstanceId = aws.String("i-self") + instance0.InstanceId = aws.String("i-0") instance0.PrivateDnsName = aws.String("instance-same.ec2.internal") instance0.PrivateIpAddress = aws.String("192.168.0.1") instance0.PublicIpAddress = aws.String("1.2.3.4") instance0.InstanceType = aws.String("c3.large") + instance0.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state0 := ec2.InstanceState{ Name: aws.String("running"), } instance0.State = &state0 //1 - instance1.InstanceId = aws.String("i-self") + instance1.InstanceId = aws.String("i-1") instance1.PrivateDnsName = aws.String("instance-same.ec2.internal") instance1.PrivateIpAddress = aws.String("192.168.0.2") instance1.InstanceType = aws.String("c3.large") + instance1.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state1 := ec2.InstanceState{ Name: aws.String("running"), } instance1.State = &state1 //2 - instance2.InstanceId = aws.String("i-self") + instance2.InstanceId = aws.String("i-2") instance2.PrivateDnsName = aws.String("instance-other.ec2.internal") instance2.PrivateIpAddress = aws.String("192.168.0.1") instance2.PublicIpAddress = aws.String("1.2.3.4") instance2.InstanceType = aws.String("c3.large") + instance2.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state2 := ec2.InstanceState{ Name: aws.String("running"), } @@ -649,19 +659,19 @@ func TestNodeAddresses(t *testing.T) { instances := []*ec2.Instance{&instance0, &instance1, &instance2} - aws1, _ := mockInstancesResp([]*ec2.Instance{}) + aws1, _ := mockInstancesResp(&instance0, []*ec2.Instance{&instance0}) _, err1 := aws1.NodeAddresses("instance-mismatch.ec2.internal") if err1 == nil { t.Errorf("Should error when no instance found") } - aws2, _ := mockInstancesResp(instances) + aws2, _ := mockInstancesResp(&instance2, instances) _, err2 := aws2.NodeAddresses("instance-same.ec2.internal") if err2 == nil { t.Errorf("Should error when multiple instances found") } - aws3, _ := mockInstancesResp(instances[0:1]) + aws3, _ := mockInstancesResp(&instance0, instances[0:1]) addrs3, err3 := aws3.NodeAddresses("instance-same.ec2.internal") if err3 != nil { t.Errorf("Should not error when instance found") @@ -673,12 +683,12 @@ func TestNodeAddresses(t *testing.T) { testHasNodeAddress(t, addrs3, api.NodeLegacyHostIP, "192.168.0.1") testHasNodeAddress(t, addrs3, api.NodeExternalIP, "1.2.3.4") - aws4, fakeServices := mockInstancesResp([]*ec2.Instance{}) - fakeServices.externalIP = "2.3.4.5" - fakeServices.internalIP = "192.168.0.2" - aws4.selfAWSInstance = &awsInstance{nodeName: fakeServices.instanceId} + // Fetch from metadata + aws4, fakeServices := mockInstancesResp(&instance0, []*ec2.Instance{&instance0}) + fakeServices.selfInstance.PublicIpAddress = aws.String("2.3.4.5") + fakeServices.selfInstance.PrivateIpAddress = aws.String("192.168.0.2") - addrs4, err4 := aws4.NodeAddresses(fakeServices.instanceId) + addrs4, err4 := aws4.NodeAddresses(*instance0.PrivateDnsName) if err4 != nil { t.Errorf("unexpected error: %v", err4) } @@ -687,7 +697,7 @@ func TestNodeAddresses(t *testing.T) { } func TestGetRegion(t *testing.T) { - aws := mockAvailabilityZone("us-west-2", "us-west-2e") + aws := mockAvailabilityZone("us-west-2e") zones, ok := aws.Zones() if !ok { t.Fatalf("Unexpected missing zones impl") @@ -820,8 +830,6 @@ func TestSubnetIDsinVPC(t *testing.T) { return } - vpcID := "vpc-deadbeef" - // test with 3 subnets from 3 different AZs subnets := make(map[int]map[string]string) subnets[0] = make(map[string]string) @@ -842,7 +850,7 @@ func TestSubnetIDsinVPC(t *testing.T) { } awsServices.ec2.RouteTables = constructRouteTables(routeTables) - result, err := c.listPublicSubnetIDsinVPC(vpcID) + result, err := c.listPublicSubnetIDsinVPC() if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -868,7 +876,7 @@ func TestSubnetIDsinVPC(t *testing.T) { // test implicit routing table - when subnets are not explicitly linked to a table they should use main awsServices.ec2.RouteTables = constructRouteTables(map[string]bool{}) - result, err = c.listPublicSubnetIDsinVPC(vpcID) + result, err = c.listPublicSubnetIDsinVPC() if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -900,7 +908,7 @@ func TestSubnetIDsinVPC(t *testing.T) { routeTables["subnet-c0000002"] = true awsServices.ec2.RouteTables = constructRouteTables(routeTables) - result, err = c.listPublicSubnetIDsinVPC(vpcID) + result, err = c.listPublicSubnetIDsinVPC() if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -928,7 +936,7 @@ func TestSubnetIDsinVPC(t *testing.T) { routeTables["subnet-d0000001"] = true routeTables["subnet-d0000002"] = true awsServices.ec2.RouteTables = constructRouteTables(routeTables) - result, err = c.listPublicSubnetIDsinVPC(vpcID) + result, err = c.listPublicSubnetIDsinVPC() if err != nil { t.Errorf("Error listing subnets: %v", err) return