diff --git a/pkg/cloudprovider/providers/aws/BUILD b/pkg/cloudprovider/providers/aws/BUILD index 35b7d5480bf..a53e5242067 100644 --- a/pkg/cloudprovider/providers/aws/BUILD +++ b/pkg/cloudprovider/providers/aws/BUILD @@ -51,6 +51,7 @@ go_library( "//vendor/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/ec2metadata:go_default_library", + "//vendor/github.com/aws/aws-sdk-go/aws/endpoints:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/session:go_default_library", "//vendor/github.com/aws/aws-sdk-go/service/autoscaling:go_default_library", diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 540540c7e9f..c1add689605 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -34,6 +34,7 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/autoscaling" @@ -42,7 +43,7 @@ import ( "github.com/aws/aws-sdk-go/service/elbv2" "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/aws-sdk-go/service/sts" - gcfg "gopkg.in/gcfg.v1" + "gopkg.in/gcfg.v1" "k8s.io/klog" "k8s.io/api/core/v1" @@ -55,7 +56,7 @@ import ( "k8s.io/client-go/kubernetes/scheme" v1core "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/tools/record" - cloudprovider "k8s.io/cloud-provider" + "k8s.io/cloud-provider" "k8s.io/kubernetes/pkg/api/v1/service" "k8s.io/kubernetes/pkg/controller" kubeletapis "k8s.io/kubernetes/pkg/kubelet/apis" @@ -571,6 +572,91 @@ type CloudConfig struct { //issue body. DisableStrictZoneCheck bool } + // [ServiceOverride "1"] + // Service = s3 + // Region = region1 + // URL = https://s3.foo.bar + // SigningRegion = signing_region + // SigningMethod = signing_method + // + // [ServiceOverride "2"] + // Service = ec2 + // Region = region2 + // URL = https://ec2.foo.bar + // SigningRegion = signing_region + // SigningMethod = signing_method + ServiceOverride map[string]*struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + } +} + +func (cfg *CloudConfig) validateOverrides() error { + if len(cfg.ServiceOverride) == 0 { + return nil + } + set := make(map[string]bool) + for onum, ovrd := range cfg.ServiceOverride { + // Note: gcfg does not space trim, so we have to when comparing to empty string "" + name := strings.TrimSpace(ovrd.Service) + if name == "" { + return fmt.Errorf("service name is missing [Service is \"\"] in override %s", onum) + } + // insure the map service name is space trimmed + ovrd.Service = name + + region := strings.TrimSpace(ovrd.Region) + if region == "" { + return fmt.Errorf("service region is missing [Region is \"\"] in override %s", onum) + } + // insure the map region is space trimmed + ovrd.Region = region + + url := strings.TrimSpace(ovrd.URL) + if url == "" { + return fmt.Errorf("url is missing [URL is \"\"] in override %s", onum) + } + signingRegion := strings.TrimSpace(ovrd.SigningRegion) + if signingRegion == "" { + return fmt.Errorf("signingRegion is missing [SigningRegion is \"\"] in override %s", onum) + } + signature := name + "_" + region + if set[signature] { + return fmt.Errorf("duplicate entry found for service override [%s] (%s in %s)", onum, name, region) + } + set[signature] = true + } + return nil +} + +func (cfg *CloudConfig) getResolver() endpoints.ResolverFunc { + defaultResolver := endpoints.DefaultResolver() + defaultResolverFn := func(service, region string, + optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + return defaultResolver.EndpointFor(service, region, optFns...) + } + if len(cfg.ServiceOverride) == 0 { + return defaultResolverFn + } + + return func(service, region string, + optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + for _, override := range cfg.ServiceOverride { + if override.Service == service && override.Region == region { + return endpoints.ResolvedEndpoint{ + URL: override.URL, + SigningRegion: override.SigningRegion, + SigningMethod: override.SigningMethod, + SigningName: override.SigningName, + }, nil + } + } + return defaultResolver.EndpointFor(service, region, optFns...) + } } // awsSdkEC2 is an implementation of the EC2 interface, backed by aws-sdk-go @@ -578,16 +664,23 @@ type awsSdkEC2 struct { ec2 *ec2.EC2 } +// Interface to make the CloudConfig immutable for awsSDKProvider +type awsCloudConfigProvider interface { + getResolver() endpoints.ResolverFunc +} + type awsSDKProvider struct { creds *credentials.Credentials + cfg awsCloudConfigProvider mutex sync.Mutex regionDelayers map[string]*CrossRequestRetryDelay } -func newAWSSDKProvider(creds *credentials.Credentials) *awsSDKProvider { +func newAWSSDKProvider(creds *credentials.Credentials, cfg *CloudConfig) *awsSDKProvider { return &awsSDKProvider{ creds: creds, + cfg: cfg, regionDelayers: make(map[string]*CrossRequestRetryDelay), } } @@ -657,7 +750,8 @@ func (p *awsSDKProvider) Compute(regionName string) (EC2, error) { Region: ®ionName, Credentials: p.creds, } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true) + awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). + WithEndpointResolver(p.cfg.getResolver()) sess, err := session.NewSession(awsConfig) if err != nil { @@ -678,7 +772,8 @@ func (p *awsSDKProvider) LoadBalancing(regionName string) (ELB, error) { Region: ®ionName, Credentials: p.creds, } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true) + awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). + WithEndpointResolver(p.cfg.getResolver()) sess, err := session.NewSession(awsConfig) if err != nil { @@ -695,7 +790,8 @@ func (p *awsSDKProvider) LoadBalancingV2(regionName string) (ELBV2, error) { Region: ®ionName, Credentials: p.creds, } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true) + awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). + WithEndpointResolver(p.cfg.getResolver()) sess, err := session.NewSession(awsConfig) if err != nil { @@ -713,7 +809,8 @@ func (p *awsSDKProvider) Autoscaling(regionName string) (ASG, error) { Region: ®ionName, Credentials: p.creds, } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true) + awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). + WithEndpointResolver(p.cfg.getResolver()) sess, err := session.NewSession(awsConfig) if err != nil { @@ -727,7 +824,9 @@ func (p *awsSDKProvider) Autoscaling(regionName string) (ASG, error) { } func (p *awsSDKProvider) Metadata() (EC2Metadata, error) { - sess, err := session.NewSession(&aws.Config{}) + sess, err := session.NewSession(&aws.Config{ + EndpointResolver: p.cfg.getResolver(), + }) if err != nil { return nil, fmt.Errorf("unable to initialize AWS session: %v", err) } @@ -741,7 +840,8 @@ func (p *awsSDKProvider) KeyManagement(regionName string) (KMS, error) { Region: ®ionName, Credentials: p.creds, } - awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true) + awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). + WithEndpointResolver(p.cfg.getResolver()) sess, err := session.NewSession(awsConfig) if err != nil { @@ -966,6 +1066,10 @@ func init() { return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err) } + if err = cfg.validateOverrides(); err != nil { + return nil, fmt.Errorf("unable to validate custom endpoint overrides: %v", err) + } + sess, err := session.NewSession(&aws.Config{}) if err != nil { return nil, fmt.Errorf("unable to initialize AWS session: %v", err) @@ -991,7 +1095,7 @@ func init() { &credentials.SharedCredentialsProvider{}, }) - aws := newAWSSDKProvider(creds) + aws := newAWSSDKProvider(creds, cfg) return newAWSCloud(*cfg, aws) }) } diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index 4d12932859b..62f7186af05 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -186,6 +186,289 @@ func TestReadAWSCloudConfig(t *testing.T) { } } +type ServiceDescriptor struct { + name string + region string + signingRegion, signingMethod string + signingName string +} + +func TestOverridesActiveConfig(t *testing.T) { + tests := []struct { + name string + + reader io.Reader + aws Services + + expectError bool + active bool + servicesOverridden []ServiceDescriptor + }{ + { + "No overrides", + strings.NewReader(` + [global] + `), + nil, + false, false, + []ServiceDescriptor{}, + }, + { + "Missing Service Name", + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Region=sregion + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Missing Service Region", + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Service=s3 + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Missing URL", + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Service="s3" + Region=sregion + SigningRegion=sregion + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Missing Signing Region", + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Service=s3 + Region=sregion + URL=https://s3.foo.bar + SigningMethod = sign + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Active Overrides", + strings.NewReader(` + [Global] + + [ServiceOverride "1"] + Service = "s3 " + Region = sregion + URL = https://s3.foo.bar + SigningRegion = sregion + SigningMethod = v4 + `), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "sregion", signingRegion: "sregion", signingMethod: "v4"}}, + }, + { + "Multiple Overridden Services", + strings.NewReader(` + [Global] + vpc = vpc-abc1234567 + + [ServiceOverride "1"] + Service=s3 + Region=sregion1 + URL=https://s3.foo.bar + SigningRegion=sregion1 + SigningMethod = v4 + + [ServiceOverride "2"] + Service=ec2 + Region=sregion2 + URL=https://ec2.foo.bar + SigningRegion=sregion2 + SigningMethod = v4`), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "sregion1", signingRegion: "sregion1", signingMethod: "v4"}, + {name: "ec2", region: "sregion2", signingRegion: "sregion2", signingMethod: "v4"}}, + }, + { + "Duplicate Services", + strings.NewReader(` + [Global] + vpc = vpc-abc1234567 + + [ServiceOverride "1"] + Service=s3 + Region=sregion1 + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign + + [ServiceOverride "2"] + Service=s3 + Region=sregion1 + URL=https://s3.foo.bar + SigningRegion=sregion + SigningMethod = sign`), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Multiple Overridden Services in Multiple regions", + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Service=s3 + Region=region1 + URL=https://s3.foo.bar + SigningRegion=sregion1 + + [ServiceOverride "2"] + Service=ec2 + Region=region2 + URL=https://ec2.foo.bar + SigningRegion=sregion + SigningMethod = v4 + `), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "region1", signingRegion: "sregion1", signingMethod: ""}, + {name: "ec2", region: "region2", signingRegion: "sregion", signingMethod: "v4"}}, + }, + { + "Multiple regions, Same Service", + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Service=s3 + Region=region1 + URL=https://s3.foo.bar + SigningRegion=sregion1 + SigningMethod = v3 + + [ServiceOverride "2"] + Service=s3 + Region=region2 + URL=https://s3.foo.bar + SigningRegion=sregion1 + SigningMethod = v4 + SigningName = "name" + `), + nil, + false, true, + []ServiceDescriptor{{name: "s3", region: "region1", signingRegion: "sregion1", signingMethod: "v3"}, + {name: "s3", region: "region2", signingRegion: "sregion1", signingMethod: "v4", signingName: "name"}}, + }, + } + + for _, test := range tests { + t.Logf("Running test case %s", test.name) + cfg, err := readAWSCloudConfig(test.reader) + if err == nil { + err = cfg.validateOverrides() + } + if test.expectError { + if err == nil { + t.Errorf("Should error for case %s (cfg=%v)", test.name, cfg) + } + } else { + if err != nil { + t.Errorf("Should succeed for case: %s, got %v", test.name, err) + } + + if len(cfg.ServiceOverride) != len(test.servicesOverridden) { + t.Errorf("Expected %d overridden services, received %d for case %s", + len(test.servicesOverridden), len(cfg.ServiceOverride), test.name) + } else { + for _, sd := range test.servicesOverridden { + var found *struct { + Service string + Region string + URL string + SigningRegion string + SigningMethod string + SigningName string + } + for _, v := range cfg.ServiceOverride { + if v.Service == sd.name && v.Region == sd.region { + found = v + break + } + } + if found == nil { + t.Errorf("Missing override for service %s in case %s", + sd.name, test.name) + } else { + if found.SigningRegion != sd.signingRegion { + t.Errorf("Expected signing region '%s', received '%s' for case %s", + sd.signingRegion, found.SigningRegion, test.name) + } + if found.SigningMethod != sd.signingMethod { + t.Errorf("Expected signing method '%s', received '%s' for case %s", + sd.signingMethod, found.SigningRegion, test.name) + } + targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) + if found.URL != targetName { + t.Errorf("Expected Endpoint '%s', received '%s' for case %s", + targetName, found.URL, test.name) + } + if found.SigningName != sd.signingName { + t.Errorf("Expected signing name '%s', received '%s' for case %s", + sd.signingName, found.SigningName, test.name) + } + + fn := cfg.getResolver() + ep1, e := fn(sd.name, sd.region, nil) + if e != nil { + t.Errorf("Expected a valid endpoint for %s in case %s", + sd.name, test.name) + } else { + targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) + if ep1.URL != targetName { + t.Errorf("Expected endpoint url: %s, received %s in case %s", + targetName, ep1.URL, test.name) + } + if ep1.SigningRegion != sd.signingRegion { + t.Errorf("Expected signing region '%s', received '%s' in case %s", + sd.signingRegion, ep1.SigningRegion, test.name) + } + if ep1.SigningMethod != sd.signingMethod { + t.Errorf("Expected signing method '%s', received '%s' in case %s", + sd.signingMethod, ep1.SigningRegion, test.name) + } + } + } + } + } + } + } +} + func TestNewAWSCloud(t *testing.T) { tests := []struct { name string