From b346cedefb1a27227e270495ab4ff82a06c3a0a6 Mon Sep 17 00:00:00 2001 From: Prateek Gogia Date: Mon, 20 Jun 2022 17:35:35 -0500 Subject: [PATCH] Add rate limiting when calling STS assume role API --- .../k8s.io/legacy-cloud-providers/aws/aws.go | 2 +- .../aws/aws_assumerole_provider.go | 62 ++++++++ .../aws/aws_assumerole_provider_test.go | 132 ++++++++++++++++++ 3 files changed, 195 insertions(+), 1 deletion(-) create mode 100644 staging/src/k8s.io/legacy-cloud-providers/aws/aws_assumerole_provider.go create mode 100644 staging/src/k8s.io/legacy-cloud-providers/aws/aws_assumerole_provider_test.go diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go index 541f741cfd1..d7bed4bd924 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws.go @@ -1207,7 +1207,7 @@ func init() { creds = credentials.NewChainCredentials( []credentials.Provider{ &credentials.EnvProvider{}, - provider, + assumeRoleProvider(provider), }) } diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_assumerole_provider.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_assumerole_provider.go new file mode 100644 index 00000000000..ad5a63b4c7f --- /dev/null +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_assumerole_provider.go @@ -0,0 +1,62 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "sync" + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" +) + +const ( + invalidateCredsAfter = 1 * time.Second +) + +// assumeRoleProviderWithRateLimiting makes sure we call the underlying provider only +// once after `invalidateCredsAfter` period +type assumeRoleProviderWithRateLimiting struct { + provider credentials.Provider + invalidateCredsAfter time.Duration + sync.RWMutex + lastError error + lastValue credentials.Value + lastRetrieveTime time.Time +} + +func assumeRoleProvider(provider credentials.Provider) credentials.Provider { + return &assumeRoleProviderWithRateLimiting{provider: provider, + invalidateCredsAfter: invalidateCredsAfter} +} + +func (l *assumeRoleProviderWithRateLimiting) Retrieve() (credentials.Value, error) { + l.Lock() + defer l.Unlock() + if time.Since(l.lastRetrieveTime) < l.invalidateCredsAfter { + if l.lastError != nil { + return credentials.Value{}, l.lastError + } + return l.lastValue, nil + } + l.lastValue, l.lastError = l.provider.Retrieve() + l.lastRetrieveTime = time.Now() + return l.lastValue, l.lastError +} + +func (l *assumeRoleProviderWithRateLimiting) IsExpired() bool { + return l.provider.IsExpired() +} diff --git a/staging/src/k8s.io/legacy-cloud-providers/aws/aws_assumerole_provider_test.go b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_assumerole_provider_test.go new file mode 100644 index 00000000000..09e947a4a11 --- /dev/null +++ b/staging/src/k8s.io/legacy-cloud-providers/aws/aws_assumerole_provider_test.go @@ -0,0 +1,132 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "fmt" + "reflect" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" +) + +func Test_assumeRoleProviderWithRateLimiting_Retrieve(t *testing.T) { + type fields struct { + provider credentials.Provider + invalidateCredsAfter time.Duration + RWMutex sync.RWMutex + lastError error + lastValue credentials.Value + lastRetrieveTime time.Time + } + tests := []struct { + name string + fields fields + want credentials.Value + wantProviderCalled bool + sleepBeforeCallingProvider time.Duration + wantErr bool + wantErrString string + }{{ + name: "Call assume role provider and verify access ID returned", + fields: fields{provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID"}}, + want: credentials.Value{AccessKeyID: "fakeID"}, + wantProviderCalled: true, + }, { + name: "Immediate call to assume role API, shouldn't call the underlying provider and return the last value", + fields: fields{ + provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID"}, + invalidateCredsAfter: 100 * time.Millisecond, + lastValue: credentials.Value{AccessKeyID: "fakeID1"}, + lastRetrieveTime: time.Now(), + }, + want: credentials.Value{AccessKeyID: "fakeID1"}, + wantProviderCalled: false, + sleepBeforeCallingProvider: 10 * time.Millisecond, + }, { + name: "Assume role provider returns an error when trying to assume a role", + fields: fields{ + provider: &fakeAssumeRoleProvider{err: fmt.Errorf("can't assume fake role")}, + invalidateCredsAfter: 10 * time.Millisecond, + lastRetrieveTime: time.Now(), + }, + wantProviderCalled: true, + wantErr: true, + wantErrString: "can't assume fake role", + sleepBeforeCallingProvider: 15 * time.Millisecond, + }, { + name: "Immediate call to assume role API, shouldn't call the underlying provider and return the last error value", + fields: fields{ + provider: &fakeAssumeRoleProvider{}, + invalidateCredsAfter: 100 * time.Millisecond, + lastRetrieveTime: time.Now(), + }, + want: credentials.Value{}, + wantProviderCalled: false, + wantErr: true, + wantErrString: "can't assume fake role", + }, { + name: "Delayed call to assume role API, should call the underlying provider", + fields: fields{ + provider: &fakeAssumeRoleProvider{accesskeyID: "fakeID2"}, + invalidateCredsAfter: 20 * time.Millisecond, + lastRetrieveTime: time.Now(), + }, + want: credentials.Value{AccessKeyID: "fakeID2"}, + wantProviderCalled: true, + sleepBeforeCallingProvider: 25 * time.Millisecond, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := &assumeRoleProviderWithRateLimiting{ + provider: tt.fields.provider, + invalidateCredsAfter: tt.fields.invalidateCredsAfter, + lastError: tt.fields.lastError, + lastValue: tt.fields.lastValue, + lastRetrieveTime: tt.fields.lastRetrieveTime, + } + time.Sleep(tt.sleepBeforeCallingProvider) + got, err := l.Retrieve() + if (err != nil) != tt.wantErr && (tt.wantErr && reflect.DeepEqual(err, tt.wantErrString)) { + t.Errorf("assumeRoleProviderWithRateLimiting.Retrieve() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("assumeRoleProviderWithRateLimiting.Retrieve() got = %v, want %v", got, tt.want) + return + } + if tt.wantProviderCalled != tt.fields.provider.(*fakeAssumeRoleProvider).providerCalled { + t.Errorf("provider called %v, want %v", tt.fields.provider.(*fakeAssumeRoleProvider).providerCalled, tt.wantProviderCalled) + } + }) + } +} + +type fakeAssumeRoleProvider struct { + accesskeyID string + err error + providerCalled bool +} + +func (f *fakeAssumeRoleProvider) Retrieve() (credentials.Value, error) { + f.providerCalled = true + return credentials.Value{AccessKeyID: f.accesskeyID}, f.err +} + +func (f *fakeAssumeRoleProvider) IsExpired() bool { return true }