diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index d4d13ebc91b..943346c40ba 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -20,7 +20,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/url" @@ -34,6 +33,7 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" + "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/service/autoscaling" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/elb" @@ -61,7 +61,7 @@ type AWSServices interface { Compute(region string) (EC2, error) LoadBalancing(region string) (ELB, error) Autoscaling(region string) (ASG, error) - Metadata() AWSMetadata + Metadata() (EC2Metadata, error) } // TODO: Should we rename this to AWS (EBS & ELB are not technically part of EC2) @@ -129,9 +129,9 @@ type ASG interface { } // Abstraction over the AWS metadata service -type AWSMetadata interface { +type EC2Metadata interface { // Query the EC2 metadata service (used to discover instance-id etc) - GetMetaData(key string) ([]byte, error) + GetMetadata(path string) (string, error) } type VolumeOptions struct { @@ -175,6 +175,7 @@ type AWSCloud struct { ec2 EC2 elb ELB asg ASG + metadata EC2Metadata cfg *AWSCloudConfig availabilityZone string region string @@ -232,8 +233,9 @@ func (p *awsSDKProvider) Autoscaling(regionName string) (ASG, error) { return client, nil } -func (p *awsSDKProvider) Metadata() AWSMetadata { - return &awsSdkMetadata{} +func (p *awsSDKProvider) Metadata() (EC2Metadata, error) { + client := ec2metadata.New(nil) + return client, nil } func stringPointerArray(orig []string) []*string { @@ -307,34 +309,16 @@ func (self *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([ } type awsSdkMetadata struct { + metadata *ec2metadata.Client } var metadataClient = http.Client{ Timeout: time.Second * 10, } -// Implements AWSMetadata.GetMetaData -func (self *awsSdkMetadata) GetMetaData(key string) ([]byte, error) { - // TODO Get an implementation of this merged into aws-sdk-go - url := "http://169.254.169.254/latest/meta-data/" + key - - res, err := metadataClient.Get(url) - if err != nil { - return nil, err - } - defer res.Body.Close() - - if res.StatusCode != 200 { - err = fmt.Errorf("Code %d returned for url %s", res.StatusCode, url) - return nil, fmt.Errorf("Error querying AWS metadata for key %s: %v", key, err) - } - - body, err := ioutil.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("Error querying AWS metadata for key %s: %v", key, err) - } - - return []byte(body), nil +// Implements EC2Metadata.GetMetadata +func (self *awsSdkMetadata) GetMetadata(path string) (string, error) { + return self.metadata.GetMetadata(path) } // Implements EC2.DescribeSecurityGroups @@ -466,7 +450,7 @@ func init() { } // readAWSCloudConfig reads an instance of AWSCloudConfig from config reader. -func readAWSCloudConfig(config io.Reader, metadata AWSMetadata) (*AWSCloudConfig, error) { +func readAWSCloudConfig(config io.Reader, metadata EC2Metadata) (*AWSCloudConfig, error) { var cfg AWSCloudConfig var err error @@ -493,15 +477,8 @@ func readAWSCloudConfig(config io.Reader, metadata AWSMetadata) (*AWSCloudConfig return &cfg, nil } -func getAvailabilityZone(metadata AWSMetadata) (string, error) { - availabilityZoneBytes, err := metadata.GetMetaData("placement/availability-zone") - if err != nil { - return "", err - } - if availabilityZoneBytes == nil || len(availabilityZoneBytes) == 0 { - return "", fmt.Errorf("Unable to determine availability-zone from instance metadata") - } - return string(availabilityZoneBytes), nil +func getAvailabilityZone(metadata EC2Metadata) (string, error) { + return metadata.GetMetadata("placement/availability-zone") } func isRegionValid(region string) bool { @@ -527,7 +504,11 @@ func isRegionValid(region string) bool { // newAWSCloud creates a new instance of AWSCloud. // AWSProvider and instanceId are primarily for tests func newAWSCloud(config io.Reader, awsServices AWSServices) (*AWSCloud, error) { - metadata := awsServices.Metadata() + metadata, err := awsServices.Metadata() + if err != nil { + return nil, fmt.Errorf("error creating AWS metadata client: %v", err) + } + cfg, err := readAWSCloudConfig(config, metadata) if err != nil { return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err) @@ -564,6 +545,7 @@ func newAWSCloud(config io.Reader, awsServices AWSServices) (*AWSCloud, error) { ec2: ec2, elb: elb, asg: asg, + metadata: metadata, cfg: cfg, region: regionName, availabilityZone: zone, @@ -1059,17 +1041,16 @@ func (s *AWSCloud) getSelfAWSInstance() (*awsInstance, error) { i := s.selfAWSInstance if i == nil { - metadata := s.awsServices.Metadata() - instanceIdBytes, err := metadata.GetMetaData("instance-id") + instanceId, err := s.metadata.GetMetadata("instance-id") if err != nil { return nil, fmt.Errorf("error fetching instance-id from ec2 metadata service: %v", err) } - privateDnsNameBytes, err := metadata.GetMetaData("local-hostname") + privateDnsName, err := s.metadata.GetMetadata("local-hostname") if err != nil { return nil, fmt.Errorf("error fetching local-hostname from ec2 metadata service: %v", err) } - i = newAWSInstance(s.ec2, string(instanceIdBytes), string(privateDnsNameBytes)) + i = newAWSInstance(s.ec2, instanceId, privateDnsName) s.selfAWSInstance = i } @@ -1254,27 +1235,24 @@ func (s *AWSCloud) describeLoadBalancer(name string) (*elb.LoadBalancerDescripti // Retrieves instance's vpc id from metadata func (self *AWSCloud) findVPCID() (string, error) { - - metadata := self.awsServices.Metadata() - macsBytes, err := metadata.GetMetaData("network/interfaces/macs/") + macs, err := self.metadata.GetMetadata("network/interfaces/macs/") if err != nil { return "", fmt.Errorf("Could not list interfaces of the instance", err) } // loop over interfaces, first vpc id returned wins - for _, macPath := range strings.Split(string(macsBytes), "\n") { - + for _, macPath := range strings.Split(macs, "\n") { if len(macPath) == 0 { continue } url := fmt.Sprintf("network/interfaces/macs/%svpc-id", macPath) - vpcIDBytes, err := metadata.GetMetaData(url) + vpcID, err := self.metadata.GetMetadata(url) if err != nil { continue } - return string(vpcIDBytes), nil + return vpcID, nil } - return "", fmt.Errorf("Could not find VPC id in instance metadata") + return "", fmt.Errorf("Could not find VPC ID in instance metadata") } // Find the VPC which self is attached to. diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index be6a94bb419..bb656bf4feb 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -83,9 +83,9 @@ func TestReadAWSCloudConfig(t *testing.T) { for _, test := range tests { t.Logf("Running test case %s", test.name) - var metadata AWSMetadata + var metadata EC2Metadata if test.aws != nil { - metadata = test.aws.Metadata() + metadata, _ = test.aws.Metadata() } cfg, err := readAWSCloudConfig(test.reader, metadata) if test.expectError { @@ -166,8 +166,8 @@ func (s *FakeAWSServices) Autoscaling(region string) (ASG, error) { return s.asg, nil } -func (s *FakeAWSServices) Metadata() AWSMetadata { - return s.metadata +func (s *FakeAWSServices) Metadata() (EC2Metadata, error) { + return s.metadata, nil } func TestFilterTags(t *testing.T) { @@ -313,31 +313,31 @@ type FakeMetadata struct { aws *FakeAWSServices } -func (self *FakeMetadata) GetMetaData(key string) ([]byte, error) { +func (self *FakeMetadata) GetMetadata(key string) (string, error) { networkInterfacesPrefix := "network/interfaces/macs/" if key == "placement/availability-zone" { - return []byte(self.aws.availabilityZone), nil + return self.aws.availabilityZone, nil } else if key == "instance-id" { - return []byte(self.aws.instanceId), nil + return self.aws.instanceId, nil } else if key == "local-hostname" { - return []byte(self.aws.privateDnsName), nil + return self.aws.privateDnsName, nil } else if strings.HasPrefix(key, networkInterfacesPrefix) { if key == networkInterfacesPrefix { - return []byte(strings.Join(self.aws.networkInterfacesMacs, "/\n") + "/\n"), nil + return strings.Join(self.aws.networkInterfacesMacs, "/\n") + "/\n", nil } else { keySplit := strings.Split(key, "/") macParam := keySplit[3] if len(keySplit) == 5 && keySplit[4] == "vpc-id" { for i, macElem := range self.aws.networkInterfacesMacs { if macParam == macElem { - return []byte(self.aws.networkInterfacesVpcIDs[i]), nil + return self.aws.networkInterfacesVpcIDs[i], nil } } } - return nil, nil + return "", nil } } else { - return nil, nil + return "", nil } }