diff --git a/staging/src/k8s.io/apimachinery/go.mod b/staging/src/k8s.io/apimachinery/go.mod
index 2d6c2af7611..4531a14c516 100644
--- a/staging/src/k8s.io/apimachinery/go.mod
+++ b/staging/src/k8s.io/apimachinery/go.mod
@@ -20,6 +20,7 @@ require (
 	github.com/spf13/pflag v1.0.5
 	github.com/stretchr/testify v1.8.1
 	golang.org/x/net v0.7.0
+	golang.org/x/time v0.0.0-20220210224613-90d013bbcef8
 	gopkg.in/inf.v0 v0.9.1
 	k8s.io/klog/v2 v2.80.1
 	k8s.io/kube-openapi v0.0.0-20230123231816-1cb3ae25d79a
diff --git a/staging/src/k8s.io/apimachinery/go.sum b/staging/src/k8s.io/apimachinery/go.sum
index ffd14f2c4d6..6c59e28f7fb 100644
--- a/staging/src/k8s.io/apimachinery/go.sum
+++ b/staging/src/k8s.io/apimachinery/go.sum
@@ -138,6 +138,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
 golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
diff --git a/staging/src/k8s.io/apimachinery/pkg/util/waitgroup/ratelimited_waitgroup.go b/staging/src/k8s.io/apimachinery/pkg/util/waitgroup/ratelimited_waitgroup.go
new file mode 100644
index 00000000000..8766390fc2e
--- /dev/null
+++ b/staging/src/k8s.io/apimachinery/pkg/util/waitgroup/ratelimited_waitgroup.go
@@ -0,0 +1,134 @@
+/*
+Copyright 2023 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package waitgroup
+
+import (
+	"context"
+	"fmt"
+	"sync"
+)
+
+// RateLimiter abstracts the rate limiter used by RateLimitedSafeWaitGroup.
+// The implementation must be thread-safe.
+type RateLimiter interface {
+	Wait(ctx context.Context) error
+}
+
+// RateLimiterFactoryFunc is used by the RateLimitedSafeWaitGroup to create a new
+// instance of a RateLimiter that will be used to rate limit the rate at which
+// the active request(s) drain. 'count' is the number of requests in
+// flight that are expected to invoke 'Done' on this wait group.
+type RateLimiterFactoryFunc func(count int) (RateLimiter, context.Context, context.CancelFunc)
+
+// RateLimitedSafeWaitGroup must not be copied after first use.
+type RateLimitedSafeWaitGroup struct {
+	wg sync.WaitGroup
+	// Once Wait is initiated, all subsequent Done invocations will be
+	// rate limited using this rate limiter.
+	limiter RateLimiter
+	stopCtx context.Context
+
+	mu sync.Mutex
+	// wait indicates whether Wait has been called; if true,
+	// then any Add with a positive delta will return an error.
+	wait bool
+	// number of request(s) currently using the wait group
+	count int
+}
+
+// Add adds delta, which may be negative, similar to sync.WaitGroup.
+// If Add with a positive delta happens after Wait, it returns an error,
+// which prevents an unsafe Add.
+func (wg *RateLimitedSafeWaitGroup) Add(delta int) error {
+	wg.mu.Lock()
+	defer wg.mu.Unlock()
+
+	if wg.wait && delta > 0 {
+		return fmt.Errorf("add with positive delta after Wait is forbidden")
+	}
+	wg.wg.Add(delta)
+	wg.count += delta
+	return nil
+}
+
+// Done decrements the WaitGroup counter; rate limiting is applied only
+// when the wait group is in waiting mode.
+func (wg *RateLimitedSafeWaitGroup) Done() {
+	var limiter RateLimiter
+	func() {
+		wg.mu.Lock()
+		defer wg.mu.Unlock()
+
+		wg.count -= 1
+		if wg.wait {
+			// we are using the limiter outside the scope of the lock
+			limiter = wg.limiter
+		}
+	}()
+
+	defer wg.wg.Done()
+	if limiter != nil {
+		limiter.Wait(wg.stopCtx)
+	}
+}
+
+// Wait blocks until the WaitGroup counter is zero, or until the hard limit
+// (the context returned by the limiter factory) expires. It returns the number
+// of active request(s) accounted for at the time Wait was invoked, and the
+// number of request(s) that have drained (finished using the wait group)
+// immediately before Wait returns. Ideally, both numbers should be equal,
+// indicating that all request(s) using the wait group have finished.
+func (wg *RateLimitedSafeWaitGroup) Wait(limiterFactory RateLimiterFactoryFunc) (int, int, error) {
+	if limiterFactory == nil {
+		return 0, 0, fmt.Errorf("rate limiter factory must be specified")
+	}
+
+	var cancel context.CancelFunc
+	var countNow, countAfter int
+	func() {
+		wg.mu.Lock()
+		defer wg.mu.Unlock()
+
+		wg.limiter, wg.stopCtx, cancel = limiterFactory(wg.count)
+		countNow = wg.count
+		wg.wait = true
+	}()
+
+	defer cancel()
+	// there should be a hard stop, in case request(s) are not responsive
+	// enough to invoke Done before the grace period is over.
+	waitDoneCh := make(chan struct{})
+	go func() {
+		defer close(waitDoneCh)
+		wg.wg.Wait()
+	}()
+
+	var err error
+	select {
+	case <-wg.stopCtx.Done():
+		err = wg.stopCtx.Err()
+	case <-waitDoneCh:
+	}
+
+	func() {
+		wg.mu.Lock()
+		defer wg.mu.Unlock()
+
+		countAfter = wg.count
+	}()
+	return countNow, countAfter, err
+}
diff --git a/staging/src/k8s.io/apimachinery/pkg/util/waitgroup/ratelimited_waitgroup_test.go b/staging/src/k8s.io/apimachinery/pkg/util/waitgroup/ratelimited_waitgroup_test.go
new file mode 100644
index 00000000000..dcaefe44a46
--- /dev/null
+++ b/staging/src/k8s.io/apimachinery/pkg/util/waitgroup/ratelimited_waitgroup_test.go
@@ -0,0 +1,320 @@
+/*
+Copyright 2023 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package waitgroup
+
+import (
+	"context"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"golang.org/x/time/rate"
+	"k8s.io/apimachinery/pkg/util/wait"
+)
+
+func TestRateLimitedSafeWaitGroup(t *testing.T) {
+	// we want to keep track of how many times the rate limiter Wait method is
+	// being invoked, both before and after the wait group is in waiting mode.
+	limiter := &limiterWrapper{}
+
+	// we expect the context passed by the factory to be used
+	var cancelInvoked int
+	factory := &factory{
+		limiter: limiter,
+		grace: 2 * time.Second,
+		ctx: context.Background(),
+		cancel: func() {
+			cancelInvoked++
+		},
+	}
+	target := &rateLimitedSafeWaitGroupWrapper{
+		RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{limiter: limiter},
+	}
+
+	// two sets of requests:
+	// - n1: this set will finish using this waitgroup before Wait is invoked
+	// - n2: this set will be in flight after Wait is invoked
+	n1, n2 := 100, 101
+
+	// so we know when all requests in n1 are done using the waitgroup
+	n1DoneWG := sync.WaitGroup{}
+
+	// so we know when all requests in n2 have called Add,
+	// but have not finished with the waitgroup yet.
+	// this will allow the test to invoke 'Wait' once all requests
+	// in n2 have called `Add`, but none has called `Done` yet.
+	n2BeforeWaitWG := sync.WaitGroup{}
+	// so we know when all requests in n2 have called Done and
+	// are finished using the waitgroup
+	n2DoneWG := sync.WaitGroup{}
+
+	startCh, blockedCh := make(chan struct{}), make(chan struct{})
+	n1DoneWG.Add(n1)
+	for i := 0; i < n1; i++ {
+		go func() {
+			defer n1DoneWG.Done()
+			<-startCh
+
+			target.Add(1)
+			// let's finish using the waitgroup immediately
+			target.Done()
+		}()
+	}
+
+	n2BeforeWaitWG.Add(n2)
+	n2DoneWG.Add(n2)
+	for i := 0; i < n2; i++ {
+		go func() {
+			func() {
+				defer n2BeforeWaitWG.Done()
+				<-startCh
+
+				target.Add(1)
+			}()
+
+			func() {
+				defer n2DoneWG.Done()
+				// let's wait for the test to instruct the requests in n2
+				// that it is time to finish using the waitgroup.
+				<-blockedCh
+
+				target.Done()
+			}()
+		}()
+	}
+
+	// initially the count should be zero
+	if count := target.Count(); count != 0 {
+		t.Errorf("expected count to be zero, but got: %d", count)
+	}
+	// start the test
+	close(startCh)
+	// wait for the first set of requests (n1) to be done
+	n1DoneWG.Wait()
+
+	// the requests in n1 finished before Wait, so the rate limiter should not have been used
+	if invoked := limiter.invoked(); invoked != 0 {
+		t.Errorf("expected no call to rate limiter before Wait is called, but got: %d", invoked)
+	}
+
+	// make sure all requests in the second group (n2) have started using the
+	// waitgroup (Add invoked) but no request is done using the waitgroup yet.
+	n2BeforeWaitWG.Wait()
+
+	// count should be n2, since every request in n2 is still using the waitgroup
+	if count := target.Count(); count != n2 {
+		t.Errorf("expected count to be: %d, but got: %d", n2, count)
+	}
+
+	// time for us to mark the waitgroup as `Waiting`
+	waitDoneCh := make(chan waitResult)
+	go func() {
+		factory.grace = 2 * time.Second
+		before, after, err := target.Wait(factory.NewRateLimiter)
+		waitDoneCh <- waitResult{before: before, after: after, err: err}
+	}()
+
+	// make sure there is no flake in the test due to this race condition
+	var waitingGot bool
+	wait.PollImmediate(500*time.Millisecond, wait.ForeverTestTimeout, func() (done bool, err error) {
+		if waiting := target.Waiting(); waiting {
+			waitingGot = true
+			return true, nil
+		}
+		return false, nil
+	})
+	// verify that the waitgroup is in 'Waiting' mode
+	if !waitingGot {
+		t.Errorf("expected to be in waiting")
+	}
+
+	// we should not allow any new request to use this waitgroup any longer
+	if err := target.Add(1); err == nil ||
+		!strings.Contains(err.Error(), "add with positive delta after Wait is forbidden") {
+		t.Errorf("expected Add to return error while in waiting mode: %v", err)
+	}
+
+	// make sure that RateLimitedSafeWaitGroup passes the right
+	// request count to the limiter factory.
+	if factory.countGot != n2 {
+		t.Errorf("expected count passed to factory to be: %d, but got: %d", n2, factory.countGot)
+	}
+
+	// indicate to all requests (each request in n2) that are
+	// currently using this waitgroup that they can go ahead
+	// and invoke 'Done' to finish using this waitgroup.
+	close(blockedCh)
+	n2DoneWG.Wait()
+
+	if invoked := limiter.invoked(); invoked != n2 {
+		t.Errorf("expected rate limiter to be called %d times, but got: %d", n2, invoked)
+	}
+
+	waitResult := <-waitDoneCh
+	if count := target.Count(); count != 0 {
+		t.Errorf("expected count to be zero, but got: %d", count)
+	}
+	if waitResult.before != n2 {
+		t.Errorf("expected count before Wait to be: %d, but got: %d", n2, waitResult.before)
+	}
+	if waitResult.after != 0 {
+		t.Errorf("expected count after Wait to be zero, but got: %d", waitResult.after)
+	}
+	if cancelInvoked != 1 {
+		t.Errorf("expected context cancel to be invoked once, but got: %d", cancelInvoked)
+	}
+}
+
+func TestRateLimitedSafeWaitGroupWithHardTimeout(t *testing.T) {
+	target := &rateLimitedSafeWaitGroupWrapper{
+		RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{},
+	}
+	n := 10
+	wg := sync.WaitGroup{}
+	wg.Add(n)
+	for i := 0; i < n; i++ {
+		go func() {
+			defer wg.Done()
+			target.Add(1)
+		}()
+	}
+
+	wg.Wait()
+	if count := target.Count(); count != n {
+		t.Errorf("expected count to be: %d, but got: %d", n, count)
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	activeAt, activeNow, err := target.Wait(func(count int) (RateLimiter, context.Context, context.CancelFunc) {
+		return nil, ctx, cancel
+	})
+	if activeAt != n {
+		t.Errorf("expected active at Wait to be: %d, but got: %d", n, activeAt)
+	}
+	if activeNow != n {
+		t.Errorf("expected active after Wait to be: %d, but got: %d", n, activeNow)
+	}
+	if err != context.Canceled {
+		t.Errorf("expected error: %v, but got: %v", context.Canceled, err)
+	}
+}
+
+func TestRateLimitedSafeWaitGroupWithBurstOfOne(t *testing.T) {
+	target := &rateLimitedSafeWaitGroupWrapper{
+		RateLimitedSafeWaitGroup: &RateLimitedSafeWaitGroup{},
+	}
+	n := 200
+	grace := 5 * time.Second
+	wg := sync.WaitGroup{}
+	wg.Add(n)
+	for i := 0; i < n; i++ {
+		go func() {
+			defer wg.Done()
+			target.Add(1)
+		}()
+	}
+	wg.Wait()
+
+	waitingCh := make(chan struct{})
+	wg.Add(n)
+	for i := 0; i < n; i++ {
+		go func() {
+			defer wg.Done()
+
+			<-waitingCh
+			target.Done()
+		}()
+	}
+	defer wg.Wait()
+
+	now := time.Now()
+	t.Logf("Wait starting, N=%d, grace: %s, at: %s", n, grace, now)
+	activeAt, activeNow, err := target.Wait(func(count int) (RateLimiter, context.Context, context.CancelFunc) {
+		defer close(waitingCh)
+		// no deadline in the context, so Wait will block indefinitely; we want
+		// to measure how long it takes for the requests to drain.
+		return rate.NewLimiter(rate.Limit(n/int(grace.Seconds())), 1), context.Background(), func() {}
+	})
+	took := time.Since(now)
+	t.Logf("Wait finished, count(before): %d, count(after): %d, took: %s, err: %v", activeAt, activeNow, took, err)
+
+	// in a CPU-starved environment, the goroutines may not finish in time
+	if took > 2*grace {
+		t.Errorf("expected Wait to take: %s, but it took: %s", grace, took)
+	}
+}
+
+type waitResult struct {
+	before, after int
+	err           error
+}
+
+type rateLimitedSafeWaitGroupWrapper struct {
+	*RateLimitedSafeWaitGroup
+}
+
+// used by test only
+func (wg *rateLimitedSafeWaitGroupWrapper) Count() int {
+	wg.mu.Lock()
+	defer wg.mu.Unlock()
+
+	return wg.count
+}
+func (wg *rateLimitedSafeWaitGroupWrapper) Waiting() bool {
+	wg.mu.Lock()
+	defer wg.mu.Unlock()
+
+	return wg.wait
+}
+
+type limiterWrapper struct {
+	delegate RateLimiter
+	lock     sync.Mutex
+	invokedN int
+}
+
+func (w *limiterWrapper) invoked() int {
+	w.lock.Lock()
+	defer w.lock.Unlock()
+	return w.invokedN
+}
+func (w *limiterWrapper) Wait(ctx context.Context) error {
+	w.lock.Lock()
+	w.invokedN++
+	w.lock.Unlock()
+
+	if w.delegate != nil {
+		w.delegate.Wait(ctx)
+	}
+	return nil
+}
+
+type factory struct {
+	limiter  *limiterWrapper
+	grace    time.Duration
+	ctx      context.Context
+	cancel   context.CancelFunc
+	countGot int
+}
+
+func (f *factory) NewRateLimiter(count int) (RateLimiter, context.Context, context.CancelFunc) {
+	f.countGot = count
+	f.limiter.delegate = rate.NewLimiter(rate.Limit(count/int(f.grace.Seconds())), 20)
+	return f.limiter, f.ctx, f.cancel
+}
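
Reviewer note (not part of the patch): the sketch below shows one way a caller might wire RateLimitedSafeWaitGroup into a server drain path, mirroring the limiter factory used in the tests above. The HTTP handler, the ":8080" address, the 10s grace period, the 2*grace hard stop, and the burst of 20 are illustrative assumptions only; nothing in this change prescribes them.

// Illustrative sketch only; names and values below are assumptions, not part of this change.
package main

import (
	"context"
	"fmt"
	"net/http"
	"time"

	"golang.org/x/time/rate"
	"k8s.io/apimachinery/pkg/util/waitgroup"
)

func main() {
	// the zero value is ready to use; it must not be copied after first use.
	wg := &waitgroup.RateLimitedSafeWaitGroup{}

	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		// each in-flight request registers itself; Add fails once Wait has started.
		if err := wg.Add(1); err != nil {
			http.Error(w, "shutting down", http.StatusServiceUnavailable)
			return
		}
		// Done is rate limited only after Wait has been invoked.
		defer wg.Done()
		fmt.Fprintln(w, "ok")
	})
	go func() { _ = http.ListenAndServe(":8080", mux) }() // error handling elided

	// ... later, on shutdown: drain in-flight requests over a grace period,
	// with the context deadline acting as the hard stop.
	grace := 10 * time.Second
	before, after, err := wg.Wait(func(count int) (waitgroup.RateLimiter, context.Context, context.CancelFunc) {
		ctx, cancel := context.WithTimeout(context.Background(), 2*grace)
		qps := rate.Limit(1)
		if seconds := int(grace.Seconds()); seconds > 0 && count > seconds {
			// spread the Done calls roughly evenly across the grace period
			qps = rate.Limit(count / seconds)
		}
		return rate.NewLimiter(qps, 20), ctx, cancel
	})
	fmt.Printf("in flight at Wait: %d, still in flight after: %d, err: %v\n", before, after, err)
}

The factory receives the in-flight request count at the moment Wait is invoked, so the per-second rate can be derived from the grace period; the context deadline is the hard stop Wait falls back to if some requests never call Done.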