diff --git a/pkg/cloudprovider/providers/aws/BUILD b/pkg/cloudprovider/providers/aws/BUILD index 755d731a0fe..c2207efaad3 100644 --- a/pkg/cloudprovider/providers/aws/BUILD +++ b/pkg/cloudprovider/providers/aws/BUILD @@ -50,6 +50,7 @@ go_library( "//vendor/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/ec2metadata:go_default_library", + "//vendor/github.com/aws/aws-sdk-go/aws/endpoints:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/session:go_default_library", "//vendor/github.com/aws/aws-sdk-go/service/autoscaling:go_default_library", diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 34b137947f9..66387d8573a 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -570,89 +570,76 @@ type CloudConfig struct { //yourself in an non-AWS cloud and open an issue, please indicate that in the //issue body. DisableStrictZoneCheck bool - - // Delimiter to use to separate region of occurrence, url and signing region for each override - // NOTE: semi-colon ';' truncates the input line in INI files, do not use ';' - // Defaults to "," - OverrideSeparator string - - // These are of format servicename, region, url, signing_region - // s3, region1, https://s3.foo.bar, some signing_region - // ec2 region1, https://ec2.foo.bar, signing_region - ServiceOverrides []string + } + // [ServiceOverride "1"] + // Name = s3 + // Region = region1 + // Url = https://s3.foo.bar + // SigningRegion = signing_region + // + // [ServiceOverride "2"] + // Name = ec2 + // Region = region2 + // Url = https://ec2.foo.bar + // SigningRegion = signing_region + ServiceOverride map[string]*struct { + Service string + Region string + Url string + SigningRegion string } } -const ( - OverrideSeparatorDefault = "," -) - -type CustomEndpoint struct { - Endpoint string - SigningRegion string -} - -var overridesActive = false -var overrides map[string]CustomEndpoint - -func setOverridesDefaults(cfg *CloudConfig) error { - if cfg.Global.OverrideSeparator == "" { - cfg.Global.OverrideSeparator = OverrideSeparatorDefault - } else if cfg.Global.OverrideSeparator == ";" { - return fmt.Errorf("semi-colon may not be used as a override separator, it truncates the input") - } - return nil -} - -func makeRegionEndpointSignature(serviceName, region string) string { - return fmt.Sprintf("%s__%s", strings.TrimSpace(serviceName), strings.TrimSpace(region)) -} - -func parseOverrides(cfg *CloudConfig) error { - overridesActive = false - if len(cfg.Global.ServiceOverrides) == 0 { +func (cfg *CloudConfig) validateOverrides() error { + if len(cfg.ServiceOverride) == 0 { return nil } - if err := setOverridesDefaults(cfg); err != nil { - return err - } - overrides = make(map[string]CustomEndpoint) - for _, ovrd := range cfg.Global.ServiceOverrides { - tokens := strings.Split(ovrd, cfg.Global.OverrideSeparator) - if len(tokens) != 4 { - if len(tokens) > 0 { - return fmt.Errorf("4 parameters (service, region, url, signing region) are required for [%s] in %s", - tokens[0], ovrd) - } - return fmt.Errorf("4 parameters (service, region, url, signing region) are required in %s", - ovrd) + set := make(map[string]bool) + for onum, ovrd := range cfg.ServiceOverride { + name := strings.TrimSpace(ovrd.Service) + if name == "" { + return fmt.Errorf("service name is missing [Service is \"\"] in override %s", onum) + } + region := strings.TrimSpace(ovrd.Region) + if region == "" { + return fmt.Errorf("service region is missing [Region is \"\"] in override %s", onum) + } + url := strings.TrimSpace(ovrd.Url) + if url == "" { + return fmt.Errorf("url is missing [Url is \"\"] in override %s", onum) + } + signingRegion := strings.TrimSpace(ovrd.SigningRegion) + if signingRegion == "" { + return fmt.Errorf("signingRegion is missing [SigningRegion is \"\"] in override %s", onum) + } + if _, ok := set[name+"_"+region]; ok { + return fmt.Errorf("duplicate entry found for service override [%s] (%s in %s)", onum, name, region) + } else { + set[name+"_"+region] = true } - name := strings.TrimSpace(tokens[0]) - region := strings.TrimSpace(tokens[1]) - url := strings.TrimSpace(tokens[2]) - signingRegion := strings.TrimSpace(tokens[3]) - signature := makeRegionEndpointSignature(name, region) - overrides[signature] = CustomEndpoint{Endpoint: url, SigningRegion: signingRegion} } - overridesActive = true return nil } -func loadCustomResolver() func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { +func (cfg *CloudConfig) getResolver() func(service, region string, + optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { defaultResolver := endpoints.DefaultResolver() - defaultResolverFn := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + defaultResolverFn := func(service, region string, + optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { return defaultResolver.EndpointFor(service, region, optFns...) } - if !overridesActive { + if len(cfg.ServiceOverride) == 0 { return defaultResolverFn } - customResolverFn := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { - signature := makeRegionEndpointSignature(service, region) - if ep, ok := overrides[signature]; ok { - return endpoints.ResolvedEndpoint{ - URL: ep.Endpoint, - SigningRegion: ep.SigningRegion, - }, nil + customResolverFn := func(service, region string, + optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + for _, override := range cfg.ServiceOverride { + if override.Service == service && override.Region == region { + return endpoints.ResolvedEndpoint{ + URL: override.Url, + SigningRegion: override.SigningRegion, + }, nil + } } return defaultResolver.EndpointFor(service, region, optFns...) } @@ -664,16 +651,23 @@ type awsSdkEC2 struct { ec2 *ec2.EC2 } +// Interface to make the CloudConfig immutable for awsSDKProvider +type awsCloudConfigProvider interface { + getResolver() func(string, string, ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) +} + type awsSDKProvider struct { creds *credentials.Credentials + cfg awsCloudConfigProvider mutex sync.Mutex regionDelayers map[string]*CrossRequestRetryDelay } -func newAWSSDKProvider(creds *credentials.Credentials) *awsSDKProvider { +func newAWSSDKProvider(creds *credentials.Credentials, cfg *CloudConfig) *awsSDKProvider { return &awsSDKProvider{ creds: creds, + cfg: cfg, regionDelayers: make(map[string]*CrossRequestRetryDelay), } } @@ -739,7 +733,7 @@ func (p *awsSDKProvider) Compute(regionName string) (EC2, error) { Credentials: p.creds, } awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(endpoints.ResolverFunc(loadCustomResolver())) + WithEndpointResolver(endpoints.ResolverFunc(p.cfg.getResolver())) sess, err := session.NewSession(awsConfig) if err != nil { @@ -761,7 +755,7 @@ func (p *awsSDKProvider) LoadBalancing(regionName string) (ELB, error) { Credentials: p.creds, } awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(endpoints.ResolverFunc(loadCustomResolver())) + WithEndpointResolver(endpoints.ResolverFunc(p.cfg.getResolver())) sess, err := session.NewSession(awsConfig) if err != nil { @@ -779,7 +773,7 @@ func (p *awsSDKProvider) LoadBalancingV2(regionName string) (ELBV2, error) { Credentials: p.creds, } awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(endpoints.ResolverFunc(loadCustomResolver())) + WithEndpointResolver(endpoints.ResolverFunc(p.cfg.getResolver())) sess, err := session.NewSession(awsConfig) if err != nil { @@ -798,7 +792,7 @@ func (p *awsSDKProvider) Autoscaling(regionName string) (ASG, error) { Credentials: p.creds, } awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(endpoints.ResolverFunc(loadCustomResolver())) + WithEndpointResolver(endpoints.ResolverFunc(p.cfg.getResolver())) sess, err := session.NewSession(awsConfig) if err != nil { @@ -813,7 +807,7 @@ func (p *awsSDKProvider) Autoscaling(regionName string) (ASG, error) { func (p *awsSDKProvider) Metadata() (EC2Metadata, error) { sess, err := session.NewSession(&aws.Config{ - EndpointResolver: endpoints.ResolverFunc(loadCustomResolver()), + EndpointResolver: endpoints.ResolverFunc(p.cfg.getResolver()), }) if err != nil { return nil, fmt.Errorf("unable to initialize AWS session: %v", err) @@ -829,7 +823,7 @@ func (p *awsSDKProvider) KeyManagement(regionName string) (KMS, error) { Credentials: p.creds, } awsConfig = awsConfig.WithCredentialsChainVerboseErrors(true). - WithEndpointResolver(endpoints.ResolverFunc(loadCustomResolver())) + WithEndpointResolver(endpoints.ResolverFunc(p.cfg.getResolver())) sess, err := session.NewSession(awsConfig) if err != nil { @@ -1054,8 +1048,8 @@ func init() { return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err) } - if err = parseOverrides(cfg); err != nil { - return nil, fmt.Errorf("unable to parse custom endpoint overrides: %v", err) + if err = cfg.validateOverrides(); err != nil { + return nil, fmt.Errorf("unable to validate custom endpoint overrides: %v", err) } sess, err := session.NewSession(&aws.Config{}) @@ -1083,7 +1077,7 @@ func init() { &credentials.SharedCredentialsProvider{}, }) - aws := newAWSSDKProvider(creds) + aws := newAWSSDKProvider(creds, cfg) return newAWSCloud(*cfg, aws) }) } diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index 3d700884050..d26d4a20a9b 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -187,7 +187,7 @@ func TestReadAWSCloudConfig(t *testing.T) { } type ServiceDescriptor struct { - name string + name string region string } @@ -203,57 +203,165 @@ func TestOverridesActiveConfig(t *testing.T) { servicesOverridden []ServiceDescriptor }{ { - "Missing Servicename Separator", - strings.NewReader("[global]\nServiceOverrides=s3sregion, https://s3.foo.bar, sregion"), + "No overrides", + strings.NewReader(` + [global] + `), + nil, + false, false, + []ServiceDescriptor{}, + }, + { + "Missing Service Name", + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Region=sregion + Url=https://s3.foo.bar + SigningRegion=sregion + `), nil, true, false, []ServiceDescriptor{}, }, { "Missing Service Region", - strings.NewReader("[global]\nServiceOverrides=s3, https://s3.foo.bar, sregion"), + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Service=s3 + Url=https://s3.foo.bar + SigningRegion=sregion + `), nil, true, false, []ServiceDescriptor{}, }, { - "Semi-colon in override delimiter", - strings.NewReader("[global]\nOverrideSeparator=;\n" + - "ServiceOverrides=s3, https://s3.foo.bar, sregion"), + "Missing URL", + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Service=s3 + Region=sregion + SigningRegion=sregion + `), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Missing Signing Region", + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Service=s3 + Region=sregion + Url=https://s3.foo.bar + `), nil, true, false, []ServiceDescriptor{}, }, { "Active Overrides", - strings.NewReader("[global]\nServiceOverrides=s3, sregion, https://s3.foo.bar, sregion"), + strings.NewReader(` + [Global] + + [ServiceOverride "1"] + Service = s3 + Region = sregion + Url = https://s3.foo.bar + SigningRegion = sregion + `), nil, false, true, []ServiceDescriptor{{name: "s3", region: "sregion"}}, }, { - "Multiple Overriden Services", - strings.NewReader("[global]\n" + - "ServiceOverrides=s3, sregion1, https://s3.foo.bar, sregion\n" + - "ServiceOverrides=ec2, sregion2, https://ec2.foo.bar, sregion"), + "Multiple Overridden Services", + strings.NewReader(` + [Global] + vpc = vpc-abc1234567 + + [ServiceOverride "1"] + Service=s3 + Region=sregion1 + Url=https://s3.foo.bar + SigningRegion=sregion + + [ServiceOverride "2"] + Service=ec2 + Region=sregion2 + Url=https://ec2.foo.bar + SigningRegion=sregion`), nil, false, true, []ServiceDescriptor{{"s3", "sregion1"}, {"ec2", "sregion2"}}, }, { - "Multiple Overriden Services in Multiple regions", - strings.NewReader("[global]\n" + - "ServiceOverrides=s3, region1, https://s3.foo.bar, sregion\n" + - "ServiceOverrides=ec2, region2, https://ec2.foo.bar, sregion"), + "Duplicate Services", + strings.NewReader(` + [Global] + vpc = vpc-abc1234567 + + [ServiceOverride "1"] + Service=s3 + Region=sregion1 + Url=https://s3.foo.bar + SigningRegion=sregion + + [ServiceOverride "2"] + Service=s3 + Region=sregion1 + Url=https://s3.foo.bar + SigningRegion=sregion`), + nil, + true, false, + []ServiceDescriptor{}, + }, + { + "Multiple Overridden Services in Multiple regions", + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Service=s3 + Region=region1 + Url=https://s3.foo.bar + SigningRegion=sregion + + [ServiceOverride "2"] + Service=ec2 + Region=region2 + Url=https://ec2.foo.bar + SigningRegion=sregion + `), nil, false, true, - []ServiceDescriptor{{"s3","region1"}, {"ec2", "region2"}}, + []ServiceDescriptor{{"s3", "region1"}, {"ec2", "region2"}}, }, { "Multiple regions, Same Service", - strings.NewReader("[global]\n" + - "ServiceOverrides=s3, region1, https://s3.foo.bar, sregion\n" + - "ServiceOverrides=s3, region2, https://s3.foo.bar, sregion"), + strings.NewReader(` + [global] + + [ServiceOverride "1"] + Service=s3 + Region=region1 + Url=https://s3.foo.bar + SigningRegion=sregion + + [ServiceOverride "2"] + Service=s3 + Region=region2 + Url=https://s3.foo.bar + SigningRegion=sregion + `), nil, false, true, []ServiceDescriptor{{"s3", "region1"}, {"s3", "region2"}}, @@ -264,162 +372,62 @@ func TestOverridesActiveConfig(t *testing.T) { t.Logf("Running test case %s", test.name) cfg, err := readAWSCloudConfig(test.reader) if err == nil { - err = parseOverrides(cfg) + err = cfg.validateOverrides() } if test.expectError { if err == nil { t.Errorf("Should error for case %s (cfg=%v)", test.name, cfg) } - if overridesActive != test.active { - t.Errorf("Incorrect active flag (%v vs %v) for case: %s", - overridesActive, test.active, test.name) - } } else { if err != nil { - t.Errorf("Should succeed for case: %s", test.name) + t.Errorf("Should succeed for case: %s, got %v", test.name, err) } - if overridesActive != test.active { - t.Errorf("Incorrect active flag (%v vs %v) for case: %s", - overridesActive, test.active, test.name) - } else { - if len(overrides) != len(test.servicesOverridden) { - t.Errorf("Expected %d overridden services, received %d for case %s", - len(test.servicesOverridden), len(overrides), test.name) - } else { - for _, sd := range test.servicesOverridden { - signature := makeRegionEndpointSignature(sd.name, sd.region) - ep, ok := overrides[signature] - if !ok { - t.Errorf("Missing override for service %s in case %s", - sd.name, test.name) - } else { - if ep.SigningRegion != "sregion" { - t.Errorf("Expected signing region 'sregion', received '%s' for case %s", - ep.SigningRegion, test.name) - } - targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) - if ep.Endpoint != targetName { - t.Errorf("Expected Endpoint '%s', received '%s' for case %s", - targetName, ep.Endpoint, test.name) - } - fn := loadCustomResolver() - ep1, e := fn(sd.name, sd.region, nil) - if e != nil { - t.Errorf("Expected a valid endpoint for %s in case %s", - sd.name, test.name) - } else { - targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) - if ep1.URL != targetName { - t.Errorf("Expected endpoint url: %s, received %s in case %s", - targetName, ep1.URL, test.name) - } - if ep1.SigningRegion != "sregion" { - t.Errorf("Expected signing region 'sregion', received '%s' in case %s", - ep1.SigningRegion, test.name) - } - } + if len(cfg.ServiceOverride) != len(test.servicesOverridden) { + t.Errorf("Expected %d overridden services, received %d for case %s", + len(test.servicesOverridden), len(cfg.ServiceOverride), test.name) + } else { + for _, sd := range test.servicesOverridden { + var found *struct { + Service string + Region string + Url string + SigningRegion string + } + for _, v := range cfg.ServiceOverride { + if v.Service == sd.name && v.Region == sd.region { + found = v + break } } - } - } - } - } -} + if found == nil { + t.Errorf("Missing override for service %s in case %s", + sd.name, test.name) + } else { + if found.SigningRegion != "sregion" { + t.Errorf("Expected signing region 'sregion', received '%s' for case %s", + found.SigningRegion, test.name) + } + targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) + if found.Url != targetName { + t.Errorf("Expected Endpoint '%s', received '%s' for case %s", + targetName, found.Url, test.name) + } -func TestOverridesDefaults(t *testing.T) { - tests := []struct { - name string - - configString string - - expectError bool - active bool - servicesOverridden []string - defaults []string - }{ - { - "Custom OverrideSeparator", - "[global]\n" + - "ServiceOverrides=s3 + sregion + https://s3.foo.bar + sregion \n" + - "OverrideSeparator=+", - false, true, - []string{"s3"}, - []string{"+"}, - }, - { - "Active Overrides", - "[global]\n" + - "ServiceOverrides=s3, sregion, https://s3.foo.bar , sregion\n" + - "ServiceOverrides=ec2, sregion, https://ec2.foo.bar, sregion", - false, true, - []string{"s3", "ec2"}, - []string{","}, - }, - } - - for _, test := range tests { - t.Logf("Running test case %s", test.name) - cfg, err := readAWSCloudConfig(strings.NewReader(test.configString)) - if err == nil { - err = parseOverrides(cfg) - } - if test.expectError { - if err == nil { - t.Errorf("Should error for case %s (cfg=%v)", test.name, cfg) - } - if overridesActive != test.active { - t.Errorf("Incorrect active flag (%v vs %v) for case: %s", - overridesActive, test.active, test.name) - } - } else { - if err != nil { - t.Errorf("Should succeed for case: %s", test.name) - } - if overridesActive != test.active { - t.Errorf("Incorrect active flag (%v vs %v) for case: %s", - overridesActive, test.active, test.name) - } else { - if cfg.Global.OverrideSeparator != test.defaults[0] { - t.Errorf("Incorrect OverrideSeparator (%s vs %s) for case %s", - cfg.Global.OverrideSeparator, test.defaults[0], test.name) - } - if len(overrides) != len(test.servicesOverridden) { - t.Errorf("Expected %d overridden services, received %d for case %s", - len(test.servicesOverridden), len(overrides), test.name) - } else { - for _, name := range test.servicesOverridden { - signature := makeRegionEndpointSignature(name, "sregion") - ep, ok := overrides[signature] - if !ok { - t.Errorf("Missing override for service %s in case %s", - name, test.name) + fn := cfg.getResolver() + ep1, e := fn(sd.name, sd.region, nil) + if e != nil { + t.Errorf("Expected a valid endpoint for %s in case %s", + sd.name, test.name) } else { - if ep.SigningRegion != "sregion" { - t.Errorf("Expected signing region 'sregion', received '%s' for case %s", - ep.SigningRegion, test.name) + targetName := fmt.Sprintf("https://%s.foo.bar", sd.name) + if ep1.URL != targetName { + t.Errorf("Expected endpoint url: %s, received %s in case %s", + targetName, ep1.URL, test.name) } - targetName := fmt.Sprintf("https://%s.foo.bar", name) - if ep.Endpoint != targetName { - t.Errorf("Expected Endpoint '%s', received '%s' for case %s", - targetName, ep.Endpoint, test.name) - } - - fn := loadCustomResolver() - ep1, e := fn(name, "sregion", nil) - if e != nil { - t.Errorf("Expected a valid endpoint for %s in case %s", - name, test.name) - } else { - targetName := fmt.Sprintf("https://%s.foo.bar", name) - if ep1.URL != targetName { - t.Errorf("Expected endpoint url: %s, received %s in case %s", - targetName, ep1.URL, test.name) - } - if ep1.SigningRegion != "sregion" { - t.Errorf("Expected signing region 'sregion', received '%s' in case %s", - ep1.SigningRegion, test.name) - } + if ep1.SigningRegion != "sregion" { + t.Errorf("Expected signing region 'sregion', received '%s' in case %s", + ep1.SigningRegion, test.name) } } }