mirror of https://github.com/rancher/steve.git
Rework around a concurrency limiter.
@@ -4,7 +4,6 @@ import (
    "context"
    "os"
    "strconv"
    "strings"
    "sync"
    "time"

@@ -15,7 +14,6 @@ import (
    "github.com/rancher/wrangler/v3/pkg/summary/client"
    "github.com/rancher/wrangler/v3/pkg/summary/informer"
    "github.com/sirupsen/logrus"
    "golang.org/x/time/rate"
    "k8s.io/apimachinery/pkg/api/meta"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/runtime"
@@ -129,6 +127,11 @@ func (h *clusterCache) addResourceEventHandler(gvk schema2.GroupVersionKind, inf
}

func (h *clusterCache) OnSchemas(schemas *schema.Collection) error {
    now := time.Now()
    defer func() {
        logrus.Debugf("steve: clusterCache.OnSchemas took %v", time.Since(now))
    }()

    h.Lock()
    defer h.Unlock()

@@ -137,7 +140,6 @@ func (h *clusterCache) OnSchemas(schemas *schema.Collection) error {
        toWait []*watcher
    )

    rateLimited, limiter := rateLimiterFromEnvironment()
    for _, id := range schemas.IDs() {
        schema := schemas.Schema(id)
        if !validSchema(schema) {
@@ -167,15 +169,6 @@ func (h *clusterCache) OnSchemas(schemas *schema.Collection) error {

        logrus.Infof("Watching metadata for %s", w.gvk)
        h.addResourceEventHandler(w.gvk, w.informer)
        go func() {
            if rateLimited {
                logrus.Debug("steve: client cache rate-limiting enabled")
                if err := limiter.Wait(ctx); err != nil {
                    logrus.Errorf("error waiting to query")
                }
            }
            w.informer.Run(w.ctx.Done())
        }()
    }

    for gvk, w := range h.watchers {
@@ -186,18 +179,29 @@ func (h *clusterCache) OnSchemas(schemas *schema.Collection) error {
        }
    }

    logrus.Debugf("steve: clusterCache.OnSchemas requires %v watchers", len(toWait))

    limiter := limiterFromEnvironment()
    for _, w := range toWait {
        ctx, cancel := context.WithTimeout(w.ctx, 15*time.Minute)
        if !cache.WaitForCacheSync(ctx.Done(), w.informer.HasSynced) {
            logrus.Errorf("failed to sync cache for %v", w.gvk)
        limiter.Execute(w.ctx, func(ctx context.Context) error {
            go func() {
                w.informer.Run(w.ctx.Done())
            }()

            ctx, cancel := context.WithTimeout(w.ctx, 15*time.Minute)
            if !cache.WaitForCacheSync(ctx.Done(), w.informer.HasSynced) {
                logrus.Errorf("failed to sync cache for %v", w.gvk)
                cancel()
                w.cancel()
                delete(h.watchers, w.gvk)
            }
            cancel()
            w.cancel()
            delete(h.watchers, w.gvk)
        }
        cancel()

            return nil
        })
    }

    return nil
    return limiter.Wait()
}

func (h *clusterCache) Get(gvk schema2.GroupVersionKind, namespace, name string) (interface{}, bool, error) {
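The hunk above moves informer start-up and cache sync inside the new limiter: each watcher is handed to limiter.Execute, and OnSchemas now returns limiter.Wait() instead of a bare nil. A minimal sketch of that shape, where watchers and startAndSync are hypothetical stand-ins for the real toWait slice and the informer-run plus WaitForCacheSync logic shown in the diff; this is the pattern only, not the exact steve code:

    // Sketch of the pattern, not the exact steve code.
    limiter := limiterFromEnvironment()
    for _, w := range watchers { // watchers: stand-in for toWait
        limiter.Execute(w.ctx, func(ctx context.Context) error {
            // startAndSync: stand-in for running the informer and waiting
            // for its cache to sync, as in the diff above.
            return startAndSync(ctx, w)
        })
    }
    // Wait blocks until every Execute call has finished and returns the
    // collected errors joined together (nil if all succeeded).
    return limiter.Wait()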
@@ -311,32 +315,18 @@ func callAll(handlers []interface{}, gvr schema2.GroupVersionKind, key string, o
    return obj, merr.NewErrors(errs...)
}

func rateLimiterFromEnvironment() (bool, *rate.Limiter) {
    rateLimited := strings.ToLower(os.Getenv("RANCHER_CACHE_RATELIMIT")) == "true"
    if !rateLimited {
        return false, nil
    }

    var qps float32 = 10000.0
    if v := os.Getenv("RANCHER_CACHE_CLIENT_QPS"); v != "" {
        parsed, err := strconv.ParseFloat(v, 32)
        if err != nil {
            logrus.Infof("steve: configuring client failed to parse RANCHER_CACHE_CLIENT_QPS: %s", err)
        } else {
            qps = float32(parsed)
        }
    }

    burst := 100
    if v := os.Getenv("RANCHER_CACHE_CLIENT_BURST"); v != "" {
func limiterFromEnvironment() *Limiter {
    var limit int = 100
    if v := os.Getenv("RANCHER_CACHE_CLIENT_LIMIT"); v != "" {
        parsed, err := strconv.Atoi(v)
        if err != nil {
            logrus.Infof("steve: configuring cache client failed to parse RANCHER_CACHE_CLIENT_BURST: %s", err)
            logrus.Infof("steve: configuring cache client failed to parse RANCHER_CACHE_CLIENT_LIMIT: %s", err)
        } else {
            burst = parsed
            limit = parsed
        }
    }

    logrus.Infof("steve: configuring client cache QPS = %v, burst = %v, ratelimiting = %v", qps, burst, rateLimited)
    return rateLimited, rate.NewLimiter(rate.Limit(qps), burst)
    logrus.Debugf("steve: configuring client cache limiter: %v", limit)

    return NewLimiter(limit)
}
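As the reworked limiterFromEnvironment above reads only RANCHER_CACHE_CLIENT_LIMIT, the earlier RANCHER_CACHE_RATELIMIT, RANCHER_CACHE_CLIENT_QPS and RANCHER_CACHE_CLIENT_BURST knobs give way to a single concurrency limit (default 100; NewLimiter in the new file below clamps values of 0 or less up to 1). A small, self-contained sketch of the same parsing behaviour, shown only to illustrate the configuration surface; this program is illustrative and not part of the commit:

    package main

    import (
        "fmt"
        "os"
        "strconv"
    )

    func main() {
        // Operator-supplied value; unset or unparseable values fall back to 100.
        os.Setenv("RANCHER_CACHE_CLIENT_LIMIT", "25")

        limit := 100
        if v := os.Getenv("RANCHER_CACHE_CLIENT_LIMIT"); v != "" {
            if parsed, err := strconv.Atoi(v); err == nil {
                limit = parsed
            }
        }
        fmt.Println("cache client concurrency limit:", limit) // 25
    }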
@@ -5,7 +5,6 @@ import (
    "fmt"
    "net/http"
    "net/http/httptest"
    "strconv"
    "sync/atomic"
    "testing"
    "time"

@@ -85,6 +84,8 @@ func TestClusterCacheRateLimitingNotEnabled(t *testing.T) {
    // (list and watch).
    // This configures the test-server to rate-limit responses.
    requestCount := len(testSchema) * 2
    // The cache spawns a Go routine that makes 2 requests for each schema
    // (list and watch).
    tstSrv := startTestServer(t, rate.NewLimiter(rate.Limit(1000),
        requestCount-2), &errorCount)

@@ -111,18 +112,12 @@ func TestClusterCacheRateLimitingEnabled(t *testing.T) {
    var errorCount int32
    ctx, cancel := context.WithCancel(context.TODO())
    defer cancel()
    t.Setenv("RANCHER_CACHE_RATELIMIT", "true")
    t.Setenv("RANCHER_CACHE_CLIENT_QPS", "10000")
    t.Setenv("RANCHER_CACHE_CLIENT_LIMIT", "3")
    requestCount := len(testSchema) * 2

    // This configures the cache-client burst to less than the number of
    // requests we'll make.
    t.Setenv("RANCHER_CACHE_CLIENT_BURST", strconv.Itoa(requestCount-4))

    // The cache makes spawns a Go routine that makes 2 requests for each schema
    // The cache spawns a Go routine that makes 2 requests for each schema
    // (list and watch).
    tstSrv := startTestServer(t, rate.NewLimiter(rate.Limit(1000),
        requestCount), &errorCount)
        requestCount-2), &errorCount)

    sf := schema.NewCollection(ctx, &types.APISchemas{}, fakeAccessSetLookup{})
    sf.Reset(testSchema)
@@ -170,7 +165,7 @@ func startTestServer(t *testing.T, rl *rate.Limiter, errors *int32) *httptest.Se
    ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        values := r.URL.Query()
        watched := values.Get("watch") == "true"
        t.Logf("Faking response to %s %v", r.URL.Path, watched)
        t.Logf("Faking response to %s watch=%v", r.URL.Path, watched)

        if !rl.Allow() {
            w.WriteHeader(http.StatusTooManyRequests)
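For readers unfamiliar with the test harness: startTestServer (partially shown above) uses a golang.org/x/time/rate limiter to decide when the fake API server should answer with 429 Too Many Requests. A condensed, hypothetical version of that idea, assuming the net/http, net/http/httptest and golang.org/x/time/rate imports already present in the test file; the real helper also appears to track errors and fake Kubernetes list/watch responses, which this sketch omits:

    // Condensed sketch of a 429-emitting test server; names and behaviour
    // are simplified relative to the real startTestServer helper.
    func newThrottledServer(rl *rate.Limiter) *httptest.Server {
        return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            if !rl.Allow() {
                w.WriteHeader(http.StatusTooManyRequests)
                return
            }
            w.WriteHeader(http.StatusOK)
        }))
    }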
pkg/clustercache/limiter.go (new file, 68 lines)
@@ -0,0 +1,68 @@
package clustercache

import (
    "context"
    "errors"
    "sync"
)

// Limiter limits the number of concurrent calls to an external service
// and collects errors from failed calls.
type Limiter struct {
    semaphore chan struct{}

    // Mutex to protect the error
    // errors.Join is _not_ Go routine safe
    err error
    mu  sync.Mutex

    wg sync.WaitGroup
}

// NewLimiter creates a new Limiter with the specified concurrency limit.
func NewLimiter(limit int) *Limiter {
    if limit <= 0 {
        limit = 1
    }
    return &Limiter{
        semaphore: make(chan struct{}, limit),
    }
}

// Execute executes the given function (representing an external service call)
// while respecting the concurrency limit. If the function returns an error,
// it is collected by the Limiter.
func (sl *Limiter) Execute(ctx context.Context, serviceFunc func(ctx context.Context) error) {
    sl.wg.Add(1)

    go func() {
        defer sl.wg.Done()

        select {
        case sl.semaphore <- struct{}{}:
            defer func() {
                <-sl.semaphore
            }()

            err := serviceFunc(ctx)
            if err != nil {
                sl.mu.Lock()
                sl.err = errors.Join(sl.err, err)
                sl.mu.Unlock()
            }
        case <-ctx.Done():
            // If the context is cancelled before acquiring a semaphore slot,
            // just exit the goroutine. No error needs to be recorded for this specific case,
            // as the call wasn't even initiated.
        }
    }()
}

// Wait blocks until all outstanding calls made via `Execute` have completed
// and returns the collected errors (if any) joined into a single error.
func (sl *Limiter) Wait() error {
    sl.wg.Wait()
    sl.mu.Lock()
    defer sl.mu.Unlock()
    return sl.err
}
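Taken together, the file defines a small semaphore-plus-WaitGroup limiter. A minimal usage sketch of its exported API (NewLimiter, Execute, Wait) follows; the function below is hypothetical, assumes the context, fmt and time imports, and assumes it runs inside the clustercache package (from elsewhere the identifiers would be package-qualified):

    // Run five calls with at most two in flight, then collect failures from Wait.
    func exampleLimiterUsage(ctx context.Context) error {
        limiter := NewLimiter(2)

        for i := 0; i < 5; i++ {
            i := i // capture the per-iteration value for the goroutine run by Execute
            limiter.Execute(ctx, func(ctx context.Context) error {
                // Stand-in for a real call, e.g. starting an informer and
                // waiting for its cache to sync.
                select {
                case <-time.After(10 * time.Millisecond):
                    if i == 3 {
                        return fmt.Errorf("call %d failed", i)
                    }
                    return nil
                case <-ctx.Done():
                    return ctx.Err()
                }
            })
        }

        // Blocks until every call has finished; the result is the errors.Join
        // of all failures, or nil if everything succeeded.
        return limiter.Wait()
    }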
pkg/clustercache/limiter_test.go (new file, 188 lines)
@@ -0,0 +1,188 @@
package clustercache

import (
    "context"
    "fmt"
    "sync/atomic"
    "testing"
    "time"

    "github.com/stretchr/testify/assert"
)

func TestNewLimiter(t *testing.T) {
    t.Run("PositiveLimit", func(t *testing.T) {
        limiter := NewLimiter(5)
        assert.NotNil(t, limiter)
        assert.Len(t, limiter.semaphore, 0)
        assert.Equal(t, 5, cap(limiter.semaphore))
    })

    t.Run("ZeroLimit", func(t *testing.T) {
        limiter := NewLimiter(0)
        assert.NotNil(t, limiter)
        assert.Equal(t, 1, cap(limiter.semaphore))
    })

    t.Run("NegativeLimit", func(t *testing.T) {
        limiter := NewLimiter(-3)
        assert.NotNil(t, limiter)
        assert.Equal(t, 1, cap(limiter.semaphore))
    })
}

func TestLimiter_Concurrency(t *testing.T) {
    totalExecutions := 10

    var activeExecutions int32
    var maxActiveExecutions int32

    limit := 3
    activeTracker := make(chan struct{}, limit)

    testCtx, cancel := context.WithCancel(context.Background())
    defer cancel()

    limiter := NewLimiter(limit)
    for i := 0; i < totalExecutions; i++ {
        limiter.Execute(testCtx, func(ctx context.Context) error {
            select {
            case activeTracker <- struct{}{}:
                currentActive := atomic.AddInt32(&activeExecutions, 1)
                defer atomic.AddInt32(&activeExecutions, -1)

                for {
                    oldMax := atomic.LoadInt32(&maxActiveExecutions)
                    if currentActive > oldMax {
                        if atomic.CompareAndSwapInt32(&maxActiveExecutions, oldMax, currentActive) {
                            break
                        }
                    } else {
                        break
                    }
                }
                time.Sleep(50 * time.Millisecond)
                <-activeTracker
                return nil
            case <-ctx.Done():
                return ctx.Err()
            }
        })
    }

    err := limiter.Wait()

    assert.Equal(t, int32(limit), maxActiveExecutions)
    assert.NoError(t, err)
}

func TestLimiter_ErrorCollection(t *testing.T) {
    limiter := NewLimiter(5)
    expectedErrors := 0
    var startedCount int32

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    // Execution 1: Success
    limiter.Execute(ctx, func(callCtx context.Context) error {
        return fakeService(callCtx, 1, false, 10*time.Millisecond, &startedCount)
    })

    // Execution 2: Failure
    limiter.Execute(ctx, func(callCtx context.Context) error {
        expectedErrors++
        return fakeService(callCtx, 2, true, 10*time.Millisecond, &startedCount)
    })

    // Execution 3: Success
    limiter.Execute(ctx, func(callCtx context.Context) error {
        return fakeService(callCtx, 3, false, 10*time.Millisecond, &startedCount)
    })

    // Execution 4: Failure
    limiter.Execute(ctx, func(callCtx context.Context) error {
        expectedErrors++
        return fakeService(callCtx, 4, true, 10*time.Millisecond, &startedCount)
    })

    err := limiter.Wait()

    assert.Equal(t, int32(4), startedCount)
    assert.ErrorContains(t, err, "error from service 2")
    assert.ErrorContains(t, err, "error from service 4")
}

func TestLimiter_ContextCancellation(t *testing.T) {
    t.Run("cancellation during execution", func(t *testing.T) {
        limiter := NewLimiter(2)
        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) // Short timeout
        defer cancel()

        var startedCount int32

        // This call should be running when context is cancelled
        limiter.Execute(ctx, func(callCtx context.Context) error {
            return fakeService(callCtx, 1, false, 200*time.Millisecond, &startedCount)
        })

        // This call should complete before cancellation
        limiter.Execute(ctx, func(callCtx context.Context) error {
            return fakeService(callCtx, 2, false, 10*time.Millisecond, &startedCount)
        })

        err := limiter.Wait()

        assert.ErrorContains(t, err, "service 1 cancelled: context deadline exceeded")
        assert.Equal(t, int32(2), startedCount)
    })

    t.Run("cancellation before acquiring semaphore", func(t *testing.T) {
        limiter := NewLimiter(1)
        ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
        defer cancel()

        var startedCount int32

        // Execute 1: Will acquire semaphore and run (and likely be cancelled)
        limiter.Execute(ctx, func(callCtx context.Context) error {
            return fakeService(callCtx, 1, false, 100*time.Millisecond, &startedCount)
        })

        // Execute 2: Will wait for semaphore, but context will likely cancel before it acquires
        limiter.Execute(ctx, func(callCtx context.Context) error {
            return fakeService(callCtx, 2, false, 200*time.Millisecond, &startedCount)
        })

        err := limiter.Wait()

        // Assert that only the first call actually started its service function
        // (the second one was cancelled before acquiring the semaphore slot).
        assert.Equal(t, int32(1), startedCount, "Only the first call should have started")

        // The first call might be cancelled, leading to 1 error.
        // The second call's goroutine will exit via `<-ctx.Done()` *before* acquiring the semaphore,
        // and thus won't add an error to the `limiter.err`
        assert.ErrorContains(t, err, "service 2 cancelled: context deadline exceeded")
    })
}

func TestLimiter_NoExecutions(t *testing.T) {
    limiter := NewLimiter(3)
    err := limiter.Wait()

    assert.NoError(t, err)
}

func fakeService(ctx context.Context, id int, simulateError bool, duration time.Duration, startedCounter *int32) error {
    atomic.AddInt32(startedCounter, 1)
    select {
    case <-ctx.Done():
        return fmt.Errorf("service %d cancelled: %v", id, ctx.Err())
    case <-time.After(duration):
        if simulateError {
            return fmt.Errorf("error from service %d", id)
        }
        return nil
    }
}