diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 67a5636923c..62e6392ca3a 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -1111,7 +1111,16 @@ func (c *Cloud) InstanceID(nodeName types.NodeName) (string, error) { // This method will not be called from the node that is requesting this ID. i.e. metadata service // and other local methods cannot be used here func (c *Cloud) InstanceTypeByProviderID(providerID string) (string, error) { - return "", errors.New("unimplemented") + // In AWS, we're using the instanceID as the providerID. + instanceID := providerID + + instance, error := c.describeInstanceByInstanceID(instanceID) + + if error != nil { + return "", error + } + + return aws.StringValue(instance.InstanceType), nil } // InstanceType returns the type of the node with the specified nodeName. @@ -3397,6 +3406,20 @@ func (c *Cloud) describeInstancesByInstanceID(instanceID string) ([]*ec2.Instanc return c.describeInstances(filters) } +func (c *Cloud) describeInstanceByInstanceID(instanceID string) (*ec2.Instance, error) { + filters := []*ec2.Filter{newEc2Filter("instance-id", instanceID)} + instances, err := c.describeInstances(filters) + if err != nil { + return nil, err + } + + if len(instances) != 1 { + return nil, fmt.Errorf("expected 1 instance, found %d for instanceID %s", len(instances), instanceID) + } + + return instances[0], nil +} + func (c *Cloud) describeInstances(filters []*ec2.Filter) ([]*ec2.Instance, error) { filters = c.tagging.addFilters(filters) request := &ec2.DescribeInstancesInput{