diff --git a/pkg/cloudprovider/providers/aws/BUILD b/pkg/cloudprovider/providers/aws/BUILD index 7fb2d5e4eb5..4ce36c72273 100644 --- a/pkg/cloudprovider/providers/aws/BUILD +++ b/pkg/cloudprovider/providers/aws/BUILD @@ -38,6 +38,7 @@ go_library( "//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds:go_default_library", + "//vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/ec2metadata:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/session:go_default_library", @@ -46,6 +47,7 @@ go_library( "//vendor/github.com/aws/aws-sdk-go/service/elb:go_default_library", "//vendor/github.com/aws/aws-sdk-go/service/elbv2:go_default_library", "//vendor/github.com/aws/aws-sdk-go/service/kms:go_default_library", + "//vendor/github.com/aws/aws-sdk-go/service/sts:go_default_library", "//vendor/github.com/golang/glog:go_default_library", "//vendor/github.com/prometheus/client_golang/prometheus:go_default_library", "//vendor/gopkg.in/gcfg.v1:go_default_library", diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 6e9b5ac718a..31faa871d68 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -33,6 +33,7 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" @@ -41,6 +42,7 @@ import ( "github.com/aws/aws-sdk-go/service/elb" "github.com/aws/aws-sdk-go/service/elbv2" "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go/service/sts" "github.com/golang/glog" "github.com/prometheus/client_golang/prometheus" clientset "k8s.io/client-go/kubernetes" @@ -526,6 +528,9 @@ type CloudConfig struct { // RouteTableID enables using a specific RouteTable RouteTableID string + // RoleARN is the IAM role to assume when interaction with AWS APIs. + RoleARN string + // KubernetesClusterTag is the legacy cluster id we'll use to identify our cluster resources KubernetesClusterTag string // KubernetesClusterID is the cluster id we'll use to identify our cluster resources @@ -927,22 +932,43 @@ func (s *awsSdkEC2) DescribeVpcs(request *ec2.DescribeVpcsInput) (*ec2.DescribeV func init() { registerMetrics() cloudprovider.RegisterCloudProvider(ProviderName, func(config io.Reader) (cloudprovider.Interface, error) { + cfg, err := readAWSCloudConfig(config) + if err != nil { + return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err) + } + + sess, err := session.NewSession(&aws.Config{}) + if err != nil { + return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + } + + var provider credentials.Provider + if cfg.Global.RoleARN == "" { + provider = &ec2rolecreds.EC2RoleProvider{ + Client: ec2metadata.New(sess), + } + } else { + glog.Infof("Using AWS assumed role %v", cfg.Global.RoleARN) + provider = &stscreds.AssumeRoleProvider{ + Client: sts.New(sess), + RoleARN: cfg.Global.RoleARN, + } + } + creds := credentials.NewChainCredentials( []credentials.Provider{ &credentials.EnvProvider{}, - &ec2rolecreds.EC2RoleProvider{ - Client: ec2metadata.New(session.New(&aws.Config{})), - }, + provider, &credentials.SharedCredentialsProvider{}, }) aws := newAWSSDKProvider(creds) - return newAWSCloud(config, aws) + return newAWSCloud(*cfg, aws) }) } // readAWSCloudConfig reads an instance of AWSCloudConfig from config reader. -func readAWSCloudConfig(config io.Reader, metadata EC2Metadata) (*CloudConfig, error) { +func readAWSCloudConfig(config io.Reader) (*CloudConfig, error) { var cfg CloudConfig var err error @@ -953,20 +979,25 @@ func readAWSCloudConfig(config io.Reader, metadata EC2Metadata) (*CloudConfig, e } } + return &cfg, nil +} + +func updateConfigZone(cfg *CloudConfig, metadata EC2Metadata) error { if cfg.Global.Zone == "" { if metadata != nil { glog.Info("Zone not specified in configuration file; querying AWS metadata service") + var err error cfg.Global.Zone, err = getAvailabilityZone(metadata) if err != nil { - return nil, err + return err } } if cfg.Global.Zone == "" { - return nil, fmt.Errorf("no zone specified in configuration file") + return fmt.Errorf("no zone specified in configuration file") } } - return &cfg, nil + return nil } func getInstanceType(metadata EC2Metadata) (string, error) { @@ -989,7 +1020,7 @@ func azToRegion(az string) (string, error) { // newAWSCloud creates a new instance of AWSCloud. // AWSProvider and instanceId are primarily for tests -func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) { +func newAWSCloud(cfg CloudConfig, awsServices Services) (*Cloud, error) { // We have some state in the Cloud object - in particular the attaching map // Log so that if we are building multiple Cloud objects, it is obvious! glog.Infof("Building AWS cloudprovider") @@ -999,9 +1030,9 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) { return nil, fmt.Errorf("error creating AWS metadata client: %q", err) } - cfg, err := readAWSCloudConfig(config, metadata) + err = updateConfigZone(&cfg, metadata) if err != nil { - return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err) + return nil, fmt.Errorf("unable to determine AWS zone from cloud provider config or EC2 instance metadata: %v", err) } zone := cfg.Global.Zone @@ -1059,7 +1090,7 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) { asg: asg, metadata: metadata, kms: kms, - cfg: cfg, + cfg: &cfg, region: regionName, attaching: make(map[types.NodeName]map[mountDevice]awsVolumeID), @@ -1067,8 +1098,9 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) { } awsCloud.instanceCache.cloud = awsCloud - if cfg.Global.VPC != "" && cfg.Global.SubnetID != "" && (cfg.Global.KubernetesClusterTag != "" || cfg.Global.KubernetesClusterID != "") { - // When the master is running on a different AWS account, cloud provider or on-premises + tagged := cfg.Global.KubernetesClusterTag != "" || cfg.Global.KubernetesClusterID != "" + if cfg.Global.VPC != "" && (cfg.Global.SubnetID != "" || cfg.Global.RoleARN != "") && tagged { + // When the master is running on a different AWS account, cloud provider or on-premise // build up a dummy instance and use the VPC from the nodes account glog.Info("Master is configured to run on a different AWS account, different cloud provider or on-premises") awsCloud.selfAWSInstance = &awsInstance{ @@ -1084,7 +1116,6 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) { } awsCloud.selfAWSInstance = selfAWSInstance awsCloud.vpcID = selfAWSInstance.vpcID - } if cfg.Global.KubernetesClusterTag != "" || cfg.Global.KubernetesClusterID != "" { diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index f7cc8012dec..5e9601730b7 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -160,7 +160,10 @@ func TestReadAWSCloudConfig(t *testing.T) { if test.aws != nil { metadata, _ = test.aws.Metadata() } - cfg, err := readAWSCloudConfig(test.reader, metadata) + cfg, err := readAWSCloudConfig(test.reader) + if err == nil { + err = updateConfigZone(cfg, metadata) + } if test.expectError { if err == nil { t.Errorf("Should error for case %s (cfg=%v)", test.name, cfg) @@ -213,7 +216,11 @@ func TestNewAWSCloud(t *testing.T) { for _, test := range tests { t.Logf("Running test case %s", test.name) - c, err := newAWSCloud(test.reader, test.awsServices) + cfg, err := readAWSCloudConfig(test.reader) + var c *Cloud + if err == nil { + c, err = newAWSCloud(*cfg, test.awsServices) + } if test.expectError { if err == nil { t.Errorf("Should error for case %s", test.name) @@ -233,7 +240,7 @@ func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (* awsServices := newMockedFakeAWSServices(TestClusterId) awsServices.instances = instances awsServices.selfInstance = selfInstance - awsCloud, err := newAWSCloud(nil, awsServices) + awsCloud, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { panic(err) } @@ -242,7 +249,7 @@ func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (* func mockAvailabilityZone(availabilityZone string) *Cloud { awsServices := newMockedFakeAWSServices(TestClusterId).WithAz(availabilityZone) - awsCloud, err := newAWSCloud(nil, awsServices) + awsCloud, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { panic(err) } @@ -389,7 +396,7 @@ func TestGetRegion(t *testing.T) { func TestFindVPCID(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -463,7 +470,7 @@ func constructRouteTable(subnetID string, public bool) *ec2.RouteTable { func TestSubnetIDsinVPC(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -798,7 +805,7 @@ func TestFindInstanceByNodeNameExcludesTerminatedInstances(t *testing.T) { instances := []*ec2.Instance{&terminatedInstance, &runningInstance} awsServices.instances = append(awsServices.instances, instances...) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -818,7 +825,7 @@ func TestFindInstanceByNodeNameExcludesTerminatedInstances(t *testing.T) { func TestGetInstanceByNodeNameBatching(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) var tag ec2.Tag tag.Key = aws.String(TagNameKubernetesClusterPrefix + TestClusterId) @@ -845,7 +852,7 @@ func TestGetInstanceByNodeNameBatching(t *testing.T) { func TestGetVolumeLabels(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) volumeId := awsVolumeID("vol-VolumeId") expectedVolumeRequest := &ec2.DescribeVolumesInput{VolumeIds: []*string{volumeId.awsString()}} @@ -867,7 +874,7 @@ func TestGetVolumeLabels(t *testing.T) { func TestDescribeLoadBalancerOnDelete(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.EnsureLoadBalancerDeleted(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}) @@ -875,7 +882,7 @@ func TestDescribeLoadBalancerOnDelete(t *testing.T) { func TestDescribeLoadBalancerOnUpdate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.UpdateLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}, []*v1.Node{}) @@ -883,7 +890,7 @@ func TestDescribeLoadBalancerOnUpdate(t *testing.T) { func TestDescribeLoadBalancerOnGet(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.GetLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}) @@ -891,7 +898,7 @@ func TestDescribeLoadBalancerOnGet(t *testing.T) { func TestDescribeLoadBalancerOnEnsure(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.EnsureLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}, []*v1.Node{}) @@ -1123,7 +1130,7 @@ func TestGetLoadBalancerAdditionalTags(t *testing.T) { func TestLBExtraSecurityGroupsAnnotation(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) sg1 := "sg-000001" sg2 := "sg-000002" @@ -1159,7 +1166,7 @@ func TestLBExtraSecurityGroupsAnnotation(t *testing.T) { func TestAddLoadBalancerTags(t *testing.T) { loadBalancerName := "test-elb" awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) want := make(map[string]string) want["tag1"] = "val1" @@ -1215,7 +1222,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := *defaultHC if test.overriddenFieldName != "" { // cater for test case with no overrides @@ -1233,7 +1240,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { t.Run("does not make an API call if the current health check is the same", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := *defaultHC timeout := int64(3) @@ -1255,7 +1262,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { t.Run("validates resulting expected health check before making an API call", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := *defaultHC invalidThreshold := int64(1) @@ -1271,7 +1278,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { t.Run("handles invalid override values", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) annotations := map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "3.3"} @@ -1283,7 +1290,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { t.Run("returns error when updating the health check fails", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) returnErr := fmt.Errorf("throttling error") awsServices.elb.(*MockedFakeELB).expectConfigureHealthCheck(&lbName, defaultHC, returnErr) diff --git a/pkg/cloudprovider/providers/aws/regions_test.go b/pkg/cloudprovider/providers/aws/regions_test.go index 50352f754f9..03fb8ff16ab 100644 --- a/pkg/cloudprovider/providers/aws/regions_test.go +++ b/pkg/cloudprovider/providers/aws/regions_test.go @@ -74,7 +74,7 @@ func TestRecognizesNewRegion(t *testing.T) { } awsServices := NewFakeAWSServices(TestClusterId).WithAz(region + "a") - _, err := newAWSCloud(nil, awsServices) + _, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { t.Errorf("error building AWS cloud: %v", err) } diff --git a/pkg/cloudprovider/providers/aws/tags_test.go b/pkg/cloudprovider/providers/aws/tags_test.go index 42185a4f941..c745451431b 100644 --- a/pkg/cloudprovider/providers/aws/tags_test.go +++ b/pkg/cloudprovider/providers/aws/tags_test.go @@ -19,13 +19,12 @@ package aws import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" - "strings" "testing" ) func TestFilterTags(t *testing.T) { awsServices := NewFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { t.Errorf("Error building aws cloud: %v", err) return