diff --git a/pkg/cloudprovider/cloud.go b/pkg/cloudprovider/cloud.go index f62dca97edf..247cb15ce78 100644 --- a/pkg/cloudprovider/cloud.go +++ b/pkg/cloudprovider/cloud.go @@ -68,6 +68,7 @@ func GetLoadBalancerName(service *v1.Service) string { return ret } +// GetInstanceProviderID builds a ProviderID for a node in a cloud. func GetInstanceProviderID(cloud Interface, nodeName types.NodeName) (string, error) { instances, ok := cloud.Instances() if !ok { diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 8fd1891d04a..eec475e7f7a 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -204,6 +204,7 @@ type Services interface { type EC2 interface { // Query EC2 for instances matching the filter DescribeInstances(request *ec2.DescribeInstancesInput) ([]*ec2.Instance, error) + DescribeAddresses(request *ec2.DescribeAddressesInput) ([]*ec2.Address, error) // Attach a volume to an instance AttachVolume(*ec2.AttachVolumeInput) (*ec2.VolumeAttachment, error) @@ -608,6 +609,20 @@ func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*e return results, nil } +// Implementation of EC2.DescribeAddresses +func (s *awsSdkEC2) DescribeAddresses(request *ec2.DescribeAddressesInput) ([]*ec2.Address, error) { + requestTime := time.Now() + response, err := s.ec2.DescribeAddresses(request) + if err != nil { + recordAwsMetric("describe_address", 0, err) + return nil, fmt.Errorf("error listing AWS addresses: %v", err) + } + + timeTaken := time.Since(requestTime).Seconds() + recordAwsMetric("describe_address", timeTaken, nil) + return response.Addresses, nil +} + // Implements EC2.DescribeSecurityGroups func (s *awsSdkEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGroupsInput) ([]*ec2.SecurityGroup, error) { // Security groups are not paged @@ -1022,7 +1037,45 @@ func (c *Cloud) NodeAddresses(name types.NodeName) ([]v1.NodeAddress, error) { // This method will not be called from the node that is requesting this ID. i.e. metadata service // and other local methods cannot be used here func (c *Cloud) NodeAddressesByProviderID(providerID string) ([]v1.NodeAddress, error) { - return []v1.NodeAddress{}, errors.New("unimplemented") + instanceID, error := instanceIDFromProviderID(providerID) + + if error != nil { + return nil, error + } + + addresses, error := c.describeAddressesByInstanceID(instanceID) + + if error != nil { + return nil, error + } + + instances, error := c.describeInstancesByInstanceID(instanceID) + + if error != nil { + return nil, error + } + + nodeAddresses := []v1.NodeAddress{} + + for _, address := range addresses { + convertedAddress, error := convertAwsAddress(address) + if error != nil { + return nil, error + } + + nodeAddresses = append(nodeAddresses, convertedAddress...) + } + + for _, instance := range instances { + addresses, error := instanceAddresses(instance) + if error != nil { + return nil, error + } + + nodeAddresses = append(nodeAddresses, addresses...) + } + + return nodeAddresses, nil } // ExternalID returns the cloud provider ID of the node with the specified nodeName (deprecated). @@ -1061,7 +1114,19 @@ func (c *Cloud) InstanceID(nodeName types.NodeName) (string, error) { // This method will not be called from the node that is requesting this ID. i.e. metadata service // and other local methods cannot be used here func (c *Cloud) InstanceTypeByProviderID(providerID string) (string, error) { - return "", errors.New("unimplemented") + instanceID, error := instanceIDFromProviderID(providerID) + + if error != nil { + return "", error + } + + instance, error := c.describeInstanceByInstanceID(instanceID) + + if error != nil { + return "", error + } + + return aws.StringValue(instance.InstanceType), nil } // InstanceType returns the type of the node with the specified nodeName. @@ -3342,6 +3407,25 @@ func (c *Cloud) getInstancesByNodeNames(nodeNames []string, states ...string) ([ return instances, nil } +func (c *Cloud) describeInstancesByInstanceID(instanceID string) ([]*ec2.Instance, error) { + filters := []*ec2.Filter{newEc2Filter("instance-id", instanceID)} + return c.describeInstances(filters) +} + +func (c *Cloud) describeInstanceByInstanceID(instanceID string) (*ec2.Instance, error) { + filters := []*ec2.Filter{newEc2Filter("instance-id", instanceID)} + instances, err := c.describeInstances(filters) + if err != nil { + return nil, err + } + + if len(instances) != 1 { + return nil, fmt.Errorf("expected 1 instance, found %d for instanceID %s", len(instances), instanceID) + } + + return instances[0], nil +} + func (c *Cloud) describeInstances(filters []*ec2.Filter) ([]*ec2.Instance, error) { filters = c.tagging.addFilters(filters) request := &ec2.DescribeInstancesInput{ @@ -3362,6 +3446,21 @@ func (c *Cloud) describeInstances(filters []*ec2.Filter) ([]*ec2.Instance, error return matches, nil } +func (c *Cloud) describeAddressesByInstanceID(instanceID string) ([]*ec2.Address, error) { + filters := []*ec2.Filter{newEc2Filter("instance-id", instanceID)} + params := &ec2.DescribeAddressesInput{ + Filters: filters, + } + + addresses, error := c.ec2.DescribeAddresses(params) + + if error != nil { + return nil, error + } + + return addresses, nil +} + // mapNodeNameToPrivateDNSName maps a k8s NodeName to an AWS Instance PrivateDNSName // This is a simple string cast func mapNodeNameToPrivateDNSName(nodeName types.NodeName) string { @@ -3419,6 +3518,78 @@ func (c *Cloud) getFullInstance(nodeName types.NodeName) (*awsInstance, *ec2.Ins return awsInstance, instance, err } +func instanceAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) { + addresses := []v1.NodeAddress{} + privateDNSName := aws.StringValue(instance.PrivateDnsName) + unsafePrivateIP := aws.StringValue(instance.PrivateIpAddress) + publicDNSName := aws.StringValue(instance.PublicDnsName) + unsafePublicIP := aws.StringValue(instance.PublicIpAddress) + + if privateDNSName != "" { + addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalDNS, Address: privateDNSName}) + } + + if unsafePrivateIP != "" { + ip := net.ParseIP(unsafePrivateIP) + if ip != nil { + addresses = append(addresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()}) + } else { + return nil, fmt.Errorf("EC2 address had invalid private IP: %s", unsafePrivateIP) + } + } + + if publicDNSName != "" { + addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalDNS, Address: publicDNSName}) + } + + if unsafePublicIP != "" { + ip := net.ParseIP(unsafePublicIP) + if ip != nil { + addresses = append(addresses, v1.NodeAddress{Type: v1.NodeExternalIP, Address: ip.String()}) + } else { + return nil, fmt.Errorf("EC2 address had invalid public IP: %s", unsafePublicIP) + } + } + + return addresses, nil +} + +func convertAwsAddress(address *ec2.Address) ([]v1.NodeAddress, error) { + nodeAddresses := []v1.NodeAddress{} + if aws.StringValue(address.PrivateIpAddress) != "" { + unsafeIP := *address.PrivateIpAddress + ip := net.ParseIP(unsafeIP) + if ip != nil { + nodeAddresses = append(nodeAddresses, v1.NodeAddress{Type: v1.NodeInternalIP, Address: ip.String()}) + } else { + return nil, fmt.Errorf("EC2 address had invalid private IP: %s", unsafeIP) + } + } + + if aws.StringValue(address.PublicIp) != "" { + unsafeIP := *address.PublicIp + ip := net.ParseIP(unsafeIP) + if ip != nil { + nodeAddresses = append(nodeAddresses, v1.NodeAddress{Type: v1.NodeExternalIP, Address: ip.String()}) + } else { + return nil, fmt.Errorf("EC2 address had invalid public IP: %s", unsafeIP) + } + } + + return nodeAddresses, nil +} + +var providerIDRegexp = regexp.MustCompile(`^aws://([^/]+)$`) + +func instanceIDFromProviderID(providerID string) (instanceID string, err error) { + matches := providerIDRegexp.FindStringSubmatch(providerID) + if len(matches) != 2 { + return "", fmt.Errorf("ProviderID \"%s\" didn't match expected format \"aws://InstanceID\"", providerID) + } + + return matches[1], nil +} + func setNodeDisk( nodeDiskMap map[types.NodeName]map[KubernetesVolumeID]bool, volumeID KubernetesVolumeID, diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index a50eac1fc81..f75c1a571c3 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -322,6 +322,12 @@ func (self *FakeEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]* return matches, nil } +func (self *FakeEC2) DescribeAddresses(request *ec2.DescribeAddressesInput) ([]*ec2.Address, error) { + addresses := []*ec2.Address{} + + return addresses, nil +} + type FakeMetadata struct { aws *FakeAWSServices } @@ -1350,3 +1356,37 @@ func TestGetLoadBalancerAdditionalTags(t *testing.T) { } } } + +func TestInstanceIDFromProviderID(t *testing.T) { + testCases := []struct { + providerID string + instanceID string + fail bool + }{ + { + providerID: "aws://i-0194bbdb81a49b169", + instanceID: "i-0194bbdb81a49b169", + fail: false, + }, + { + providerID: "i-0194bbdb81a49b169", + instanceID: "", + fail: true, + }, + } + + for _, test := range testCases { + instanceID, err := instanceIDFromProviderID(test.providerID) + if (err != nil) != test.fail { + t.Errorf("%s yielded `err != nil` as %t. expected %t", test.providerID, (err != nil), test.fail) + } + + if test.fail { + continue + } + + if instanceID != test.instanceID { + t.Errorf("%s yielded %s. expected %s", test.providerID, instanceID, test.instanceID) + } + } +}