diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 795f8a7fab7..1574d8a3b7f 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -194,17 +194,18 @@ type InstanceGroupInfo interface { // AWSCloud is an implementation of Interface, LoadBalancer and Instances for Amazon Web Services. type AWSCloud struct { - ec2 EC2 - elb ELB - asg ASG - metadata EC2Metadata - cfg *AWSCloudConfig - availabilityZone string - region string + ec2 EC2 + elb ELB + asg ASG + metadata EC2Metadata + cfg *AWSCloudConfig + region string + vpcID string filterTags map[string]string // The AWS instance that we are running on + // Note that we cache some state in awsInstance (mountpoints), so we must preserve the instance selfAWSInstance *awsInstance mutex sync.Mutex @@ -368,12 +369,8 @@ func (self *AWSCloud) AddSSHKeyToAllInstances(user string, keyData []byte) error return errors.New("unimplemented") } -func (a *AWSCloud) CurrentNodeName(hostname string) (string, error) { - selfInstance, err := a.getSelfAWSInstance() - if err != nil { - return "", err - } - return selfInstance.nodeName, nil +func (c *AWSCloud) CurrentNodeName(hostname string) (string, error) { + return c.selfAWSInstance.nodeName, nil } // Implementation of EC2.Instances @@ -628,28 +625,32 @@ func newAWSCloud(config io.Reader, awsServices AWSServices) (*AWSCloud, error) { } awsCloud := &AWSCloud{ - ec2: ec2, - elb: elb, - asg: asg, - metadata: metadata, - cfg: cfg, - region: regionName, - availabilityZone: zone, + ec2: ec2, + elb: elb, + asg: asg, + metadata: metadata, + cfg: cfg, + region: regionName, } + selfAWSInstance, err := awsCloud.buildSelfAWSInstance() + if err != nil { + return nil, err + } + + awsCloud.selfAWSInstance = selfAWSInstance + awsCloud.vpcID = selfAWSInstance.vpcID + filterTags := map[string]string{} if cfg.Global.KubernetesClusterTag != "" { filterTags[TagNameKubernetesCluster] = cfg.Global.KubernetesClusterTag } else { - selfInstance, err := awsCloud.getSelfAWSInstance() + // TODO: Clean up double-API query + info, err := selfAWSInstance.describeInstance() if err != nil { return nil, err } - selfInstanceInfo, err := selfInstance.getInfo() - if err != nil { - return nil, err - } - for _, tag := range selfInstanceInfo.Tags { + for _, tag := range info.Tags { if orEmpty(tag.Key) == TagNameKubernetesCluster { filterTags[TagNameKubernetesCluster] = orEmpty(tag.Value) } @@ -710,15 +711,11 @@ func (aws *AWSCloud) Routes() (cloudprovider.Routes, bool) { } // NodeAddresses is an implementation of Instances.NodeAddresses. -func (aws *AWSCloud) NodeAddresses(name string) ([]api.NodeAddress, error) { - self, err := aws.getSelfAWSInstance() - if err != nil { - return nil, err - } - if self.nodeName == name || len(name) == 0 { +func (c *AWSCloud) NodeAddresses(name string) ([]api.NodeAddress, error) { + if c.selfAWSInstance.nodeName == name || len(name) == 0 { addresses := []api.NodeAddress{} - internalIP, err := aws.metadata.GetMetadata("local-ipv4") + internalIP, err := c.metadata.GetMetadata("local-ipv4") if err != nil { return nil, err } @@ -726,7 +723,7 @@ func (aws *AWSCloud) NodeAddresses(name string) ([]api.NodeAddress, error) { // Legacy compatibility: the private ip was the legacy host ip addresses = append(addresses, api.NodeAddress{Type: api.NodeLegacyHostIP, Address: internalIP}) - externalIP, err := aws.metadata.GetMetadata("public-ipv4") + externalIP, err := c.metadata.GetMetadata("public-ipv4") if err != nil { //TODO: It would be nice to be able to determine the reason for the failure, // but the AWS client masks all failures with the same error description. @@ -737,7 +734,7 @@ func (aws *AWSCloud) NodeAddresses(name string) ([]api.NodeAddress, error) { return addresses, nil } - instance, err := aws.getInstanceByNodeName(name) + instance, err := c.getInstanceByNodeName(name) if err != nil { return nil, err } @@ -770,19 +767,14 @@ func (aws *AWSCloud) NodeAddresses(name string) ([]api.NodeAddress, error) { } // ExternalID returns the cloud provider ID of the specified instance (deprecated). -func (aws *AWSCloud) ExternalID(name string) (string, error) { - awsInstance, err := aws.getSelfAWSInstance() - if err != nil { - return "", err - } - - if awsInstance.nodeName == name { +func (c *AWSCloud) ExternalID(name string) (string, error) { + if c.selfAWSInstance.nodeName == name { // We assume that if this is run on the instance itself, the instance exists and is alive - return awsInstance.awsID, nil + return c.selfAWSInstance.awsID, nil } else { // We must verify that the instance still exists // Note that if the instance does not exist or is no longer running, we must return ("", cloudprovider.InstanceNotFound) - instance, err := aws.findInstanceByNodeName(name) + instance, err := c.findInstanceByNodeName(name) if err != nil { return "", err } @@ -794,18 +786,13 @@ func (aws *AWSCloud) ExternalID(name string) (string, error) { } // InstanceID returns the cloud provider ID of the specified instance. -func (aws *AWSCloud) InstanceID(name string) (string, error) { - awsInstance, err := aws.getSelfAWSInstance() - if err != nil { - return "", err - } - +func (c *AWSCloud) InstanceID(name string) (string, error) { // In the future it is possible to also return an endpoint as: // // - if awsInstance.nodeName == name { - return "/" + awsInstance.availabilityZone + "/" + awsInstance.awsID, nil + if c.selfAWSInstance.nodeName == name { + return "/" + c.selfAWSInstance.availabilityZone + "/" + c.selfAWSInstance.awsID, nil } else { - inst, err := aws.getInstanceByNodeName(name) + inst, err := c.getInstanceByNodeName(name) if err != nil { return "", err } @@ -814,16 +801,11 @@ func (aws *AWSCloud) InstanceID(name string) (string, error) { } // InstanceType returns the type of the specified instance. -func (aws *AWSCloud) InstanceType(name string) (string, error) { - awsInstance, err := aws.getSelfAWSInstance() - if err != nil { - return "", err - } - - if awsInstance.nodeName == name { - return awsInstance.instanceType, nil +func (c *AWSCloud) InstanceType(name string) (string, error) { + if c.selfAWSInstance.nodeName == name { + return c.selfAWSInstance.instanceType, nil } else { - inst, err := aws.getInstanceByNodeName(name) + inst, err := c.getInstanceByNodeName(name) if err != nil { return "", err } @@ -891,10 +873,10 @@ func (aws *AWSCloud) List(filter string) ([]string, error) { } // GetZone implements Zones.GetZone -func (self *AWSCloud) GetZone() (cloudprovider.Zone, error) { +func (c *AWSCloud) GetZone() (cloudprovider.Zone, error) { return cloudprovider.Zone{ - FailureDomain: self.availabilityZone, - Region: self.region, + FailureDomain: c.selfAWSInstance.availabilityZone, + Region: c.region, }, nil } @@ -929,6 +911,12 @@ type awsInstance struct { // availability zone the instance resides in availabilityZone string + // ID of VPC the instance resides in + vpcID string + + // ID of subnet the instance resides in + subnetID string + // instance type instanceType string @@ -939,8 +927,21 @@ type awsInstance struct { deviceMappings map[mountDevice]string } -func newAWSInstance(ec2 EC2, awsID, nodeName, availabilityZone, instanceType string) *awsInstance { - self := &awsInstance{ec2: ec2, awsID: awsID, nodeName: nodeName, availabilityZone: availabilityZone, instanceType: instanceType} +// newAWSInstance creates a new awsInstance object +func newAWSInstance(ec2Service EC2, instance *ec2.Instance) *awsInstance { + az := "" + if instance.Placement != nil { + az = aws.StringValue(instance.Placement.AvailabilityZone) + } + self := &awsInstance{ + ec2: ec2Service, + awsID: aws.StringValue(instance.InstanceId), + nodeName: aws.StringValue(instance.PrivateDnsName), + availabilityZone: az, + instanceType: aws.StringValue(instance.InstanceType), + vpcID: aws.StringValue(instance.VpcId), + subnetID: aws.StringValue(instance.SubnetId), + } // We lazy-init deviceMappings self.deviceMappings = nil @@ -956,7 +957,7 @@ func (self *awsInstance) getInstanceType() *awsInstanceType { } // Gets the full information about this instance from the EC2 API -func (self *awsInstance) getInfo() (*ec2.Instance, error) { +func (self *awsInstance) describeInstance() (*ec2.Instance, error) { instanceID := self.awsID request := &ec2.DescribeInstancesInput{ InstanceIds: []*string{&instanceID}, @@ -992,7 +993,7 @@ func (self *awsInstance) getMountDevice(volumeID string, assign bool) (assigned // We cache both for efficiency and correctness if self.deviceMappings == nil { - info, err := self.getInfo() + info, err := self.describeInstance() if err != nil { return "", false, err } @@ -1073,15 +1074,22 @@ type awsDisk struct { name string // id in AWS awsID string - // az which holds the volume - az string } func newAWSDisk(aws *AWSCloud, name string) (*awsDisk, error) { - if !strings.HasPrefix(name, "aws://") { - name = "aws://" + aws.availabilityZone + "/" + name - } // name looks like aws://availability-zone/id + + // The original idea of the URL-style name was to put the AZ into the + // host, so we could find the AZ immediately from the name without + // querying the API. But it turns out we don't actually need it for + // Ubernetes-Lite, as we put the AZ into the labels on the PV instead. + // However, if in future we want to support Ubernetes-Lite + // volume-awareness without using PersistentVolumes, we likely will + // want the AZ in the host. + + if !strings.HasPrefix(name, "aws://") { + name = "aws://" + "" + "/" + name + } url, err := url.Parse(name) if err != nil { // TODO: Maybe we should pass a URL into the Volume functions @@ -1100,19 +1108,13 @@ func newAWSDisk(aws *AWSCloud, name string) (*awsDisk, error) { if strings.Contains(awsID, "/") || !strings.HasPrefix(awsID, "vol-") { return nil, fmt.Errorf("Invalid format for AWS volume (%s)", name) } - az := url.Host - // TODO: Better validation? - // TODO: Default to our AZ? Look it up? - // TODO: Should this be a region or an AZ? - if az == "" { - return nil, fmt.Errorf("Invalid format for AWS volume (%s)", name) - } - disk := &awsDisk{ec2: aws.ec2, name: name, awsID: awsID, az: az} + + disk := &awsDisk{ec2: aws.ec2, name: name, awsID: awsID} return disk, nil } // Gets the full information about this volume from the EC2 API -func (self *awsDisk) getInfo() (*ec2.Volume, error) { +func (self *awsDisk) describeVolume() (*ec2.Volume, error) { volumeID := self.awsID request := &ec2.DescribeVolumesInput{ @@ -1138,7 +1140,7 @@ func (self *awsDisk) waitForAttachmentStatus(status string) error { maxAttempts := 60 for { - info, err := self.getInfo() + info, err := self.describeVolume() if err != nil { return err } @@ -1191,59 +1193,44 @@ func (self *awsDisk) deleteVolume() (bool, error) { return true, nil } -// Gets the awsInstance for the EC2 instance on which we are running -// may return nil in case of error -func (s *AWSCloud) getSelfAWSInstance() (*awsInstance, error) { - // Note that we cache some state in awsInstance (mountpoints), so we must preserve the instance - - s.mutex.Lock() - defer s.mutex.Unlock() - - i := s.selfAWSInstance - if i == nil { - instanceId, err := s.metadata.GetMetadata("instance-id") - if err != nil { - return nil, fmt.Errorf("error fetching instance-id from ec2 metadata service: %v", err) - } - // privateDnsName, err := s.metadata.GetMetadata("local-hostname") - // See #11543 - need to use ec2 API to get the privateDnsName in case of private dns zone e.g. mydomain.io - instance, err := s.getInstanceByID(instanceId) - if err != nil { - return nil, fmt.Errorf("error finding instance %s: %v", instanceId, err) - } - privateDnsName := aws.StringValue(instance.PrivateDnsName) - availabilityZone, err := getAvailabilityZone(s.metadata) - if err != nil { - return nil, fmt.Errorf("error fetching availability zone from ec2 metadata service: %v", err) - } - instanceType, err := getInstanceType(s.metadata) - if err != nil { - return nil, fmt.Errorf("error fetching instance type from ec2 metadata service: %v", err) - } - - i = newAWSInstance(s.ec2, instanceId, privateDnsName, availabilityZone, instanceType) - s.selfAWSInstance = i +// Builds the awsInstance for the EC2 instance on which we are running. +// This is called when the AWSCloud is initialized, and should not be called otherwise (because the awsInstance for the local instance is a singleton with drive mapping state) +func (c *AWSCloud) buildSelfAWSInstance() (*awsInstance, error) { + if c.selfAWSInstance != nil { + panic("do not call buildSelfAWSInstance directly") + } + instanceId, err := c.metadata.GetMetadata("instance-id") + if err != nil { + return nil, fmt.Errorf("error fetching instance-id from ec2 metadata service: %v", err) } - return i, nil + // We want to fetch the hostname via the EC2 metadata service + // (`GetMetadata("local-hostname")`): But see #11543 - we need to use + // the EC2 API to get the privateDnsName in case of a private DNS zone + // e.g. mydomain.io, because the metadata service returns the wrong + // hostname. Once we're doing that, we might as well get all our + // information from the instance returned by the EC2 API - it is a + // single API call to get all the information, and it means we don't + // have two code paths. + instance, err := c.getInstanceByID(instanceId) + if err != nil { + return nil, fmt.Errorf("error finding instance %s: %v", instanceId, err) + } + return newAWSInstance(c.ec2, instance), nil } // Gets the awsInstance with node-name nodeName, or the 'self' instance if nodeName == "" -func (aws *AWSCloud) getAwsInstance(nodeName string) (*awsInstance, error) { +func (c *AWSCloud) getAwsInstance(nodeName string) (*awsInstance, error) { var awsInstance *awsInstance - var err error if nodeName == "" { - awsInstance, err = aws.getSelfAWSInstance() - if err != nil { - return nil, fmt.Errorf("error getting self-instance: %v", err) - } + awsInstance = c.selfAWSInstance } else { - instance, err := aws.getInstanceByNodeName(nodeName) + instance, err := c.getInstanceByNodeName(nodeName) if err != nil { return nil, fmt.Errorf("error finding instance %s: %v", nodeName, err) } - awsInstance = newAWSInstance(aws.ec2, orEmpty(instance.InstanceId), orEmpty(instance.PrivateDnsName), orEmpty(instance.Placement.AvailabilityZone), orEmpty(instance.InstanceType)) + awsInstance = newAWSInstance(c.ec2, instance) } return awsInstance, nil @@ -1382,10 +1369,13 @@ func (aws *AWSCloud) DetachDisk(diskName string, instanceName string) (string, e // Implements Volumes.CreateVolume func (s *AWSCloud) CreateDisk(volumeOptions *VolumeOptions) (string, error) { - // TODO: Should we tag this with the cluster id (so it gets deleted when the cluster does?) + // Default to creating in the current zone + // TODO: Spread across zones? + createAZ := s.selfAWSInstance.availabilityZone + // TODO: Should we tag this with the cluster id (so it gets deleted when the cluster does?) request := &ec2.CreateVolumeInput{} - request.AvailabilityZone = &s.availabilityZone + request.AvailabilityZone = &createAZ volSize := int64(volumeOptions.CapacityGB) request.Size = &volSize request.VolumeType = aws.String(DefaultVolumeType) @@ -1415,8 +1405,8 @@ func (s *AWSCloud) CreateDisk(volumeOptions *VolumeOptions) (string, error) { } // Implements Volumes.DeleteDisk -func (aws *AWSCloud) DeleteDisk(volumeName string) (bool, error) { - awsDisk, err := newAWSDisk(aws, volumeName) +func (c *AWSCloud) DeleteDisk(volumeName string) (bool, error) { + awsDisk, err := newAWSDisk(c, volumeName) if err != nil { return false, err } @@ -1429,7 +1419,7 @@ func (c *AWSCloud) GetVolumeLabels(volumeName string) (map[string]string, error) if err != nil { return nil, err } - info, err := awsDisk.getInfo() + info, err := awsDisk.describeVolume() if err != nil { return nil, err } @@ -1810,7 +1800,7 @@ func (s *AWSCloud) ensureClusterTags(resourceID string, tags []*ec2.Tag) error { // Makes sure the security group exists. // For multi-cluster isolation, name must be globally unique, for example derived from the service UUID. // Returns the security group id or error -func (s *AWSCloud) ensureSecurityGroup(name string, description string, vpcID string) (string, error) { +func (s *AWSCloud) ensureSecurityGroup(name string, description string) (string, error) { groupID := "" attempt := 0 for { @@ -1819,7 +1809,7 @@ func (s *AWSCloud) ensureSecurityGroup(name string, description string, vpcID st request := &ec2.DescribeSecurityGroupsInput{} filters := []*ec2.Filter{ newEc2Filter("group-name", name), - newEc2Filter("vpc-id", vpcID), + newEc2Filter("vpc-id", s.vpcID), } // Note that we do _not_ add our tag filters; group-name + vpc-id is the EC2 primary key. // However, we do check that it matches our tags. @@ -1846,7 +1836,7 @@ func (s *AWSCloud) ensureSecurityGroup(name string, description string, vpcID st } createRequest := &ec2.CreateSecurityGroupInput{} - createRequest.VpcId = &vpcID + createRequest.VpcId = &s.vpcID createRequest.GroupName = &name createRequest.Description = &description @@ -1928,9 +1918,9 @@ func (s *AWSCloud) createTags(resourceID string, tags map[string]string) error { } } -func (s *AWSCloud) listPublicSubnetIDsinVPC(vpcId string) ([]string, error) { +func (s *AWSCloud) listPublicSubnetIDsinVPC() ([]string, error) { sRequest := &ec2.DescribeSubnetsInput{} - vpcIdFilter := newEc2Filter("vpc-id", vpcId) + vpcIdFilter := newEc2Filter("vpc-id", s.vpcID) var filters []*ec2.Filter filters = append(filters, vpcIdFilter) filters = s.addFilters(filters) @@ -2061,14 +2051,8 @@ func (s *AWSCloud) EnsureLoadBalancer(name, region string, publicIP net.IP, port return nil, err } - vpcId, err := s.findVPCID() - if err != nil { - glog.Error("Error finding VPC", err) - return nil, err - } - // Construct list of configured subnets - subnetIDs, err := s.listPublicSubnetIDsinVPC(vpcId) + subnetIDs, err := s.listPublicSubnetIDsinVPC() if err != nil { glog.Error("Error listing subnets in VPC: ", err) return nil, err @@ -2079,7 +2063,7 @@ func (s *AWSCloud) EnsureLoadBalancer(name, region string, publicIP net.IP, port { sgName := "k8s-elb-" + name sgDescription := fmt.Sprintf("Security group for Kubernetes ELB %s (%v)", name, serviceName) - securityGroupID, err = s.ensureSecurityGroup(sgName, sgDescription, vpcId) + securityGroupID, err = s.ensureSecurityGroup(sgName, sgDescription) if err != nil { glog.Error("Error creating load balancer security group: ", err) return nil, err diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index 1004714c6b0..d92931f0459 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -110,14 +110,11 @@ func TestReadAWSCloudConfig(t *testing.T) { } type FakeAWSServices struct { - availabilityZone string + region string instances []*ec2.Instance - instanceId string - privateDnsName string + selfInstance *ec2.Instance networkInterfacesMacs []string networkInterfacesVpcIDs []string - internalIP string - externalIP string ec2 *FakeEC2 elb *FakeELB @@ -127,7 +124,7 @@ type FakeAWSServices struct { func NewFakeAWSServices() *FakeAWSServices { s := &FakeAWSServices{} - s.availabilityZone = "us-east-1a" + s.region = "us-east-1" s.ec2 = &FakeEC2{aws: s} s.elb = &FakeELB{aws: s} s.asg = &FakeASG{aws: s} @@ -136,14 +133,16 @@ func NewFakeAWSServices() *FakeAWSServices { s.networkInterfacesMacs = []string{"aa:bb:cc:dd:ee:00", "aa:bb:cc:dd:ee:01"} s.networkInterfacesVpcIDs = []string{"vpc-mac0", "vpc-mac1"} - s.instanceId = "i-self" - s.privateDnsName = "ip-172-20-0-100.ec2.internal" - s.internalIP = "192.168.0.1" - s.externalIP = "1.2.3.4" - var selfInstance ec2.Instance - selfInstance.InstanceId = &s.instanceId - selfInstance.PrivateDnsName = &s.privateDnsName - s.instances = []*ec2.Instance{&selfInstance} + selfInstance := &ec2.Instance{} + selfInstance.InstanceId = aws.String("i-self") + selfInstance.Placement = &ec2.Placement{ + AvailabilityZone: aws.String("us-east-1a"), + } + selfInstance.PrivateDnsName = aws.String("ip-172-20-0-100.ec2.internal") + selfInstance.PrivateIpAddress = aws.String("192.168.0.1") + selfInstance.PublicIpAddress = aws.String("1.2.3.4") + s.selfInstance = selfInstance + s.instances = []*ec2.Instance{selfInstance} var tag ec2.Tag tag.Key = aws.String(TagNameKubernetesCluster) @@ -154,12 +153,10 @@ func NewFakeAWSServices() *FakeAWSServices { } func (s *FakeAWSServices) withAz(az string) *FakeAWSServices { - s.availabilityZone = az - return s -} - -func (s *FakeAWSServices) withInstances(instances []*ec2.Instance) *FakeAWSServices { - s.instances = instances + if s.selfInstance.Placement == nil { + s.selfInstance.Placement = &ec2.Placement{} + } + s.selfInstance.Placement.AvailabilityZone = aws.String(az) return s } @@ -205,7 +202,7 @@ func TestNewAWSCloud(t *testing.T) { awsServices AWSServices expectError bool - zone string + region string }{ { "No config reader", @@ -220,14 +217,13 @@ func TestNewAWSCloud(t *testing.T) { { "Config specifies valid zone", strings.NewReader("[global]\nzone = eu-west-1a"), NewFakeAWSServices(), - false, "eu-west-1a", + false, "eu-west-1", }, { "Gets zone from metadata when not in config", - strings.NewReader("[global]\n"), NewFakeAWSServices(), - false, "us-east-1a", + false, "us-east-1", }, { "No zone in config or metadata", @@ -247,9 +243,9 @@ func TestNewAWSCloud(t *testing.T) { } else { if err != nil { t.Errorf("Should succeed for case: %s, got %v", test.name, err) - } else if c.availabilityZone != test.zone { - t.Errorf("Incorrect zone value (%s vs %s) for case: %s", - c.availabilityZone, test.zone, test.name) + } else if c.region != test.region { + t.Errorf("Incorrect region value (%s vs %s) for case: %s", + c.region, test.region, test.name) } } } @@ -309,8 +305,8 @@ func (self *FakeEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]* } found := false - for _, instanceId := range request.InstanceIds { - if *instanceId == *instance.InstanceId { + for _, instanceID := range request.InstanceIds { + if *instanceID == *instance.InstanceId { found = true break } @@ -343,16 +339,21 @@ type FakeMetadata struct { func (self *FakeMetadata) GetMetadata(key string) (string, error) { networkInterfacesPrefix := "network/interfaces/macs/" + i := self.aws.selfInstance if key == "placement/availability-zone" { - return self.aws.availabilityZone, nil + az := "" + if i.Placement != nil { + az = aws.StringValue(i.Placement.AvailabilityZone) + } + return az, nil } else if key == "instance-id" { - return self.aws.instanceId, nil + return aws.StringValue(i.InstanceId), nil } else if key == "local-hostname" { - return self.aws.privateDnsName, nil + return aws.StringValue(i.PrivateDnsName), nil } else if key == "local-ipv4" { - return self.aws.internalIP, nil + return aws.StringValue(i.PrivateIpAddress), nil } else if key == "public-ipv4" { - return self.aws.externalIP, nil + return aws.StringValue(i.PublicIpAddress), nil } else if strings.HasPrefix(key, networkInterfacesPrefix) { if key == networkInterfacesPrefix { return strings.Join(self.aws.networkInterfacesMacs, "/\n") + "/\n", nil @@ -499,22 +500,24 @@ func (a *FakeASG) DescribeAutoScalingGroups(*autoscaling.DescribeAutoScalingGrou panic("Not implemented") } -func mockInstancesResp(instances []*ec2.Instance) (*AWSCloud, *FakeAWSServices) { - awsServices := NewFakeAWSServices().withInstances(instances) - return &AWSCloud{ - ec2: awsServices.ec2, - availabilityZone: awsServices.availabilityZone, - metadata: &FakeMetadata{aws: awsServices}, - }, awsServices +func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (*AWSCloud, *FakeAWSServices) { + awsServices := NewFakeAWSServices() + awsServices.instances = instances + awsServices.selfInstance = selfInstance + awsCloud, err := newAWSCloud(nil, awsServices) + if err != nil { + panic(err) + } + return awsCloud, awsServices } -func mockAvailabilityZone(region string, availabilityZone string) *AWSCloud { +func mockAvailabilityZone(availabilityZone string) *AWSCloud { awsServices := NewFakeAWSServices().withAz(availabilityZone) - return &AWSCloud{ - ec2: awsServices.ec2, - availabilityZone: awsServices.availabilityZone, - region: region, + awsCloud, err := newAWSCloud(nil, awsServices) + if err != nil { + panic(err) } + return awsCloud } func TestList(t *testing.T) { @@ -532,6 +535,7 @@ func TestList(t *testing.T) { instance0.Tags = []*ec2.Tag{&tag0} instance0.InstanceId = aws.String("instance0") instance0.PrivateDnsName = aws.String("instance0.ec2.internal") + instance0.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state0 := ec2.InstanceState{ Name: aws.String("running"), } @@ -545,6 +549,7 @@ func TestList(t *testing.T) { instance1.Tags = []*ec2.Tag{&tag1} instance1.InstanceId = aws.String("instance1") instance1.PrivateDnsName = aws.String("instance1.ec2.internal") + instance1.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state1 := ec2.InstanceState{ Name: aws.String("running"), } @@ -558,6 +563,7 @@ func TestList(t *testing.T) { instance2.Tags = []*ec2.Tag{&tag2} instance2.InstanceId = aws.String("instance2") instance2.PrivateDnsName = aws.String("instance2.ec2.internal") + instance2.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state2 := ec2.InstanceState{ Name: aws.String("running"), } @@ -571,13 +577,14 @@ func TestList(t *testing.T) { instance3.Tags = []*ec2.Tag{&tag3} instance3.InstanceId = aws.String("instance3") instance3.PrivateDnsName = aws.String("instance3.ec2.internal") + instance3.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state3 := ec2.InstanceState{ Name: aws.String("running"), } instance3.State = &state3 instances := []*ec2.Instance{&instance0, &instance1, &instance2, &instance3} - aws, _ := mockInstancesResp(instances) + aws, _ := mockInstancesResp(&instance0, instances) table := []struct { input string @@ -616,32 +623,35 @@ func TestNodeAddresses(t *testing.T) { var instance2 ec2.Instance //0 - instance0.InstanceId = aws.String("i-self") + instance0.InstanceId = aws.String("i-0") instance0.PrivateDnsName = aws.String("instance-same.ec2.internal") instance0.PrivateIpAddress = aws.String("192.168.0.1") instance0.PublicIpAddress = aws.String("1.2.3.4") instance0.InstanceType = aws.String("c3.large") + instance0.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state0 := ec2.InstanceState{ Name: aws.String("running"), } instance0.State = &state0 //1 - instance1.InstanceId = aws.String("i-self") + instance1.InstanceId = aws.String("i-1") instance1.PrivateDnsName = aws.String("instance-same.ec2.internal") instance1.PrivateIpAddress = aws.String("192.168.0.2") instance1.InstanceType = aws.String("c3.large") + instance1.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state1 := ec2.InstanceState{ Name: aws.String("running"), } instance1.State = &state1 //2 - instance2.InstanceId = aws.String("i-self") + instance2.InstanceId = aws.String("i-2") instance2.PrivateDnsName = aws.String("instance-other.ec2.internal") instance2.PrivateIpAddress = aws.String("192.168.0.1") instance2.PublicIpAddress = aws.String("1.2.3.4") instance2.InstanceType = aws.String("c3.large") + instance2.Placement = &ec2.Placement{AvailabilityZone: aws.String("us-east-1a")} state2 := ec2.InstanceState{ Name: aws.String("running"), } @@ -649,19 +659,19 @@ func TestNodeAddresses(t *testing.T) { instances := []*ec2.Instance{&instance0, &instance1, &instance2} - aws1, _ := mockInstancesResp([]*ec2.Instance{}) + aws1, _ := mockInstancesResp(&instance0, []*ec2.Instance{&instance0}) _, err1 := aws1.NodeAddresses("instance-mismatch.ec2.internal") if err1 == nil { t.Errorf("Should error when no instance found") } - aws2, _ := mockInstancesResp(instances) + aws2, _ := mockInstancesResp(&instance2, instances) _, err2 := aws2.NodeAddresses("instance-same.ec2.internal") if err2 == nil { t.Errorf("Should error when multiple instances found") } - aws3, _ := mockInstancesResp(instances[0:1]) + aws3, _ := mockInstancesResp(&instance0, instances[0:1]) addrs3, err3 := aws3.NodeAddresses("instance-same.ec2.internal") if err3 != nil { t.Errorf("Should not error when instance found") @@ -673,12 +683,12 @@ func TestNodeAddresses(t *testing.T) { testHasNodeAddress(t, addrs3, api.NodeLegacyHostIP, "192.168.0.1") testHasNodeAddress(t, addrs3, api.NodeExternalIP, "1.2.3.4") - aws4, fakeServices := mockInstancesResp([]*ec2.Instance{}) - fakeServices.externalIP = "2.3.4.5" - fakeServices.internalIP = "192.168.0.2" - aws4.selfAWSInstance = &awsInstance{nodeName: fakeServices.instanceId} + // Fetch from metadata + aws4, fakeServices := mockInstancesResp(&instance0, []*ec2.Instance{&instance0}) + fakeServices.selfInstance.PublicIpAddress = aws.String("2.3.4.5") + fakeServices.selfInstance.PrivateIpAddress = aws.String("192.168.0.2") - addrs4, err4 := aws4.NodeAddresses(fakeServices.instanceId) + addrs4, err4 := aws4.NodeAddresses(*instance0.PrivateDnsName) if err4 != nil { t.Errorf("unexpected error: %v", err4) } @@ -687,7 +697,7 @@ func TestNodeAddresses(t *testing.T) { } func TestGetRegion(t *testing.T) { - aws := mockAvailabilityZone("us-west-2", "us-west-2e") + aws := mockAvailabilityZone("us-west-2e") zones, ok := aws.Zones() if !ok { t.Fatalf("Unexpected missing zones impl") @@ -820,8 +830,6 @@ func TestSubnetIDsinVPC(t *testing.T) { return } - vpcID := "vpc-deadbeef" - // test with 3 subnets from 3 different AZs subnets := make(map[int]map[string]string) subnets[0] = make(map[string]string) @@ -842,7 +850,7 @@ func TestSubnetIDsinVPC(t *testing.T) { } awsServices.ec2.RouteTables = constructRouteTables(routeTables) - result, err := c.listPublicSubnetIDsinVPC(vpcID) + result, err := c.listPublicSubnetIDsinVPC() if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -868,7 +876,7 @@ func TestSubnetIDsinVPC(t *testing.T) { // test implicit routing table - when subnets are not explicitly linked to a table they should use main awsServices.ec2.RouteTables = constructRouteTables(map[string]bool{}) - result, err = c.listPublicSubnetIDsinVPC(vpcID) + result, err = c.listPublicSubnetIDsinVPC() if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -900,7 +908,7 @@ func TestSubnetIDsinVPC(t *testing.T) { routeTables["subnet-c0000002"] = true awsServices.ec2.RouteTables = constructRouteTables(routeTables) - result, err = c.listPublicSubnetIDsinVPC(vpcID) + result, err = c.listPublicSubnetIDsinVPC() if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -928,7 +936,7 @@ func TestSubnetIDsinVPC(t *testing.T) { routeTables["subnet-d0000001"] = true routeTables["subnet-d0000002"] = true awsServices.ec2.RouteTables = constructRouteTables(routeTables) - result, err = c.listPublicSubnetIDsinVPC(vpcID) + result, err = c.listPublicSubnetIDsinVPC() if err != nil { t.Errorf("Error listing subnets: %v", err) return @@ -1177,7 +1185,7 @@ func TestGetVolumeLabels(t *testing.T) { awsServices.ec2.On("DescribeVolumes", expectedVolumeRequest).Return([]*ec2.Volume{ { VolumeId: volumeId, - AvailabilityZone: &awsServices.availabilityZone, + AvailabilityZone: aws.String("us-east-1a"), }, })