From a621cbfd16be21d73e0fc29bddc75e46cb639d2d Mon Sep 17 00:00:00 2001 From: Kevin McDermott Date: Fri, 20 Jun 2025 12:37:31 +0100 Subject: [PATCH] Rework around a concurrency limiter. --- pkg/clustercache/controller.go | 74 +++++------ pkg/clustercache/controller_test.go | 17 +-- pkg/clustercache/limiter.go | 68 ++++++++++ pkg/clustercache/limiter_test.go | 188 ++++++++++++++++++++++++++++ 4 files changed, 294 insertions(+), 53 deletions(-) create mode 100644 pkg/clustercache/limiter.go create mode 100644 pkg/clustercache/limiter_test.go diff --git a/pkg/clustercache/controller.go b/pkg/clustercache/controller.go index 274ae8e3..29c85406 100644 --- a/pkg/clustercache/controller.go +++ b/pkg/clustercache/controller.go @@ -4,7 +4,6 @@ import ( "context" "os" "strconv" - "strings" "sync" "time" @@ -15,7 +14,6 @@ import ( "github.com/rancher/wrangler/v3/pkg/summary/client" "github.com/rancher/wrangler/v3/pkg/summary/informer" "github.com/sirupsen/logrus" - "golang.org/x/time/rate" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -129,6 +127,11 @@ func (h *clusterCache) addResourceEventHandler(gvk schema2.GroupVersionKind, inf } func (h *clusterCache) OnSchemas(schemas *schema.Collection) error { + now := time.Now() + defer func() { + logrus.Debugf("steve: clusterCache.OnSchemas took %v", time.Since(now)) + }() + h.Lock() defer h.Unlock() @@ -137,7 +140,6 @@ func (h *clusterCache) OnSchemas(schemas *schema.Collection) error { toWait []*watcher ) - rateLimited, limiter := rateLimiterFromEnvironment() for _, id := range schemas.IDs() { schema := schemas.Schema(id) if !validSchema(schema) { @@ -167,15 +169,6 @@ func (h *clusterCache) OnSchemas(schemas *schema.Collection) error { logrus.Infof("Watching metadata for %s", w.gvk) h.addResourceEventHandler(w.gvk, w.informer) - go func() { - if rateLimited { - logrus.Debug("steve: client cache rate-limiting enabled") - if err := limiter.Wait(ctx); err != nil { - logrus.Errorf("error waiting to query") - } - } - w.informer.Run(w.ctx.Done()) - }() } for gvk, w := range h.watchers { @@ -186,18 +179,29 @@ func (h *clusterCache) OnSchemas(schemas *schema.Collection) error { } } + logrus.Debugf("steve: clusterCache.OnSchemas requires %v watchers", len(toWait)) + + limiter := limiterFromEnvironment() for _, w := range toWait { - ctx, cancel := context.WithTimeout(w.ctx, 15*time.Minute) - if !cache.WaitForCacheSync(ctx.Done(), w.informer.HasSynced) { - logrus.Errorf("failed to sync cache for %v", w.gvk) + limiter.Execute(w.ctx, func(ctx context.Context) error { + go func() { + w.informer.Run(w.ctx.Done()) + }() + + ctx, cancel := context.WithTimeout(w.ctx, 15*time.Minute) + if !cache.WaitForCacheSync(ctx.Done(), w.informer.HasSynced) { + logrus.Errorf("failed to sync cache for %v", w.gvk) + cancel() + w.cancel() + delete(h.watchers, w.gvk) + } cancel() - w.cancel() - delete(h.watchers, w.gvk) - } - cancel() + + return nil + }) } - return nil + return limiter.Wait() } func (h *clusterCache) Get(gvk schema2.GroupVersionKind, namespace, name string) (interface{}, bool, error) { @@ -311,32 +315,18 @@ func callAll(handlers []interface{}, gvr schema2.GroupVersionKind, key string, o return obj, merr.NewErrors(errs...) } -func rateLimiterFromEnvironment() (bool, *rate.Limiter) { - rateLimited := strings.ToLower(os.Getenv("RANCHER_CACHE_RATELIMIT")) == "true" - if !rateLimited { - return false, nil - } - - var qps float32 = 10000.0 - if v := os.Getenv("RANCHER_CACHE_CLIENT_QPS"); v != "" { - parsed, err := strconv.ParseFloat(v, 32) - if err != nil { - logrus.Infof("steve: configuring client failed to parse RANCHER_CACHE_CLIENT_QPS: %s", err) - } else { - qps = float32(parsed) - } - } - - burst := 100 - if v := os.Getenv("RANCHER_CACHE_CLIENT_BURST"); v != "" { +func limiterFromEnvironment() *Limiter { + var limit int = 100 + if v := os.Getenv("RANCHER_CACHE_CLIENT_LIMIT"); v != "" { parsed, err := strconv.Atoi(v) if err != nil { - logrus.Infof("steve: configuring cache client failed to parse RANCHER_CACHE_CLIENT_BURST: %s", err) + logrus.Infof("steve: configuring cache client failed to parse RANCHER_CACHE_CLIENT_LIMIT: %s", err) } else { - burst = parsed + limit = parsed } } - logrus.Infof("steve: configuring client cache QPS = %v, burst = %v, ratelimiting = %v", qps, burst, rateLimited) - return rateLimited, rate.NewLimiter(rate.Limit(qps), burst) + logrus.Debugf("steve: configuring client cache limiter: %v", limit) + + return NewLimiter(limit) } diff --git a/pkg/clustercache/controller_test.go b/pkg/clustercache/controller_test.go index 98667aec..2f5fd3af 100644 --- a/pkg/clustercache/controller_test.go +++ b/pkg/clustercache/controller_test.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "strconv" "sync/atomic" "testing" "time" @@ -85,6 +84,8 @@ func TestClusterCacheRateLimitingNotEnabled(t *testing.T) { // (list and watch). // This configures the test-server to rate-limit responses. requestCount := len(testSchema) * 2 + // The cache spawns a Go routine that makes 2 requests for each schema + // (list and watch). tstSrv := startTestServer(t, rate.NewLimiter(rate.Limit(1000), requestCount-2), &errorCount) @@ -111,18 +112,12 @@ func TestClusterCacheRateLimitingEnabled(t *testing.T) { var errorCount int32 ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - t.Setenv("RANCHER_CACHE_RATELIMIT", "true") - t.Setenv("RANCHER_CACHE_CLIENT_QPS", "10000") + t.Setenv("RANCHER_CACHE_CLIENT_LIMIT", "3") requestCount := len(testSchema) * 2 - - // This configures the cache-client burst to less than the number of - // requests we'll make. - t.Setenv("RANCHER_CACHE_CLIENT_BURST", strconv.Itoa(requestCount-4)) - - // The cache makes spawns a Go routine that makes 2 requests for each schema + // The cache spawns a Go routine that makes 2 requests for each schema // (list and watch). tstSrv := startTestServer(t, rate.NewLimiter(rate.Limit(1000), - requestCount), &errorCount) + requestCount-2), &errorCount) sf := schema.NewCollection(ctx, &types.APISchemas{}, fakeAccessSetLookup{}) sf.Reset(testSchema) @@ -170,7 +165,7 @@ func startTestServer(t *testing.T, rl *rate.Limiter, errors *int32) *httptest.Se ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { values := r.URL.Query() watched := values.Get("watch") == "true" - t.Logf("Faking response to %s %v", r.URL.Path, watched) + t.Logf("Faking response to %s watch=%v", r.URL.Path, watched) if !rl.Allow() { w.WriteHeader(http.StatusTooManyRequests) diff --git a/pkg/clustercache/limiter.go b/pkg/clustercache/limiter.go new file mode 100644 index 00000000..98d81ad1 --- /dev/null +++ b/pkg/clustercache/limiter.go @@ -0,0 +1,68 @@ +package clustercache + +import ( + "context" + "errors" + "sync" +) + +// Limiter limits the number of concurrent calls to an external service +// and collects errors from failed calls. +type Limiter struct { + semaphore chan struct{} + + // Mutex to protect the error + // errors.Join is _not_ Go routine safe + err error + mu sync.Mutex + + wg sync.WaitGroup +} + +// NewLimiter creates a new Limiter with the specified concurrency limit. +func NewLimiter(limit int) *Limiter { + if limit <= 0 { + limit = 1 + } + return &Limiter{ + semaphore: make(chan struct{}, limit), + } +} + +// Execute executes the given function (representing an external service call) +// while respecting the concurrency limit. If the function returns an error, +// it is collected by the Limiter. +func (sl *Limiter) Execute(ctx context.Context, serviceFunc func(ctx context.Context) error) { + sl.wg.Add(1) + + go func() { + defer sl.wg.Done() + + select { + case sl.semaphore <- struct{}{}: + defer func() { + <-sl.semaphore + }() + + err := serviceFunc(ctx) + if err != nil { + sl.mu.Lock() + sl.err = errors.Join(sl.err, err) + sl.mu.Unlock() + } + case <-ctx.Done(): + // If the context is cancelled before acquiring a semaphore slot, + // just exit the goroutine. No error needs to be recorded for this specific case, + // as the call wasn't even initiated. + } + }() +} + +// Wait blocks until all outstanding calls made via `Execute` have completed +// and returns a slice of all collected errors. +func (sl *Limiter) Wait() error { + sl.wg.Wait() + sl.mu.Lock() + defer sl.mu.Unlock() + return sl.err +} diff --git a/pkg/clustercache/limiter_test.go b/pkg/clustercache/limiter_test.go new file mode 100644 index 00000000..2d3a2e1e --- /dev/null +++ b/pkg/clustercache/limiter_test.go @@ -0,0 +1,188 @@ +package clustercache + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewLimiter(t *testing.T) { + t.Run("PositiveLimit", func(t *testing.T) { + limiter := NewLimiter(5) + assert.NotNil(t, limiter) + assert.Len(t, limiter.semaphore, 0) + assert.Equal(t, 5, cap(limiter.semaphore)) + }) + + t.Run("ZeroLimit", func(t *testing.T) { + limiter := NewLimiter(0) + assert.NotNil(t, limiter) + assert.Equal(t, 1, cap(limiter.semaphore)) + }) + + t.Run("NegativeLimit", func(t *testing.T) { + limiter := NewLimiter(-3) + assert.NotNil(t, limiter) + assert.Equal(t, 1, cap(limiter.semaphore)) + }) +} + +func TestLimiter_Concurrency(t *testing.T) { + totalExecutions := 10 + + var activeExecutions int32 + var maxActiveExecutions int32 + + limit := 3 + activeTracker := make(chan struct{}, limit) + + testCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + limiter := NewLimiter(limit) + for i := 0; i < totalExecutions; i++ { + limiter.Execute(testCtx, func(ctx context.Context) error { + select { + case activeTracker <- struct{}{}: + currentActive := atomic.AddInt32(&activeExecutions, 1) + defer atomic.AddInt32(&activeExecutions, -1) + + for { + oldMax := atomic.LoadInt32(&maxActiveExecutions) + if currentActive > oldMax { + if atomic.CompareAndSwapInt32(&maxActiveExecutions, oldMax, currentActive) { + break + } + } else { + break + } + } + time.Sleep(50 * time.Millisecond) + <-activeTracker + return nil + case <-ctx.Done(): + return ctx.Err() + } + }) + } + + err := limiter.Wait() + + assert.Equal(t, int32(limit), maxActiveExecutions) + assert.NoError(t, err) +} + +func TestLimiter_ErrorCollection(t *testing.T) { + limiter := NewLimiter(5) + expectedErrors := 0 + var startedCount int32 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Execution 1: Success + limiter.Execute(ctx, func(callCtx context.Context) error { + return fakeService(callCtx, 1, false, 10*time.Millisecond, &startedCount) + }) + + // Execution 2: Failure + limiter.Execute(ctx, func(callCtx context.Context) error { + expectedErrors++ + return fakeService(callCtx, 2, true, 10*time.Millisecond, &startedCount) + }) + + // Execution 3: Success + limiter.Execute(ctx, func(callCtx context.Context) error { + return fakeService(callCtx, 3, false, 10*time.Millisecond, &startedCount) + }) + + // Execution 4: Failure + limiter.Execute(ctx, func(callCtx context.Context) error { + expectedErrors++ + return fakeService(callCtx, 4, true, 10*time.Millisecond, &startedCount) + }) + + err := limiter.Wait() + + assert.Equal(t, int32(4), startedCount) + assert.ErrorContains(t, err, "error from service 2") + assert.ErrorContains(t, err, "error from service 4") +} + +func TestLimiter_ContextCancellation(t *testing.T) { + t.Run("cancellation during execution", func(t *testing.T) { + limiter := NewLimiter(2) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) // Short timeout + defer cancel() + + var startedCount int32 + + // This call should be running when context is cancelled + limiter.Execute(ctx, func(callCtx context.Context) error { + return fakeService(callCtx, 1, false, 200*time.Millisecond, &startedCount) + }) + + // This call should complete before cancellation + limiter.Execute(ctx, func(callCtx context.Context) error { + return fakeService(callCtx, 2, false, 10*time.Millisecond, &startedCount) + }) + + err := limiter.Wait() + + assert.ErrorContains(t, err, "service 1 cancelled: context deadline exceeded") + assert.Equal(t, int32(2), startedCount) + }) + + t.Run("cancellation before acquiring semaphore", func(t *testing.T) { + limiter := NewLimiter(1) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + var startedCount int32 + + // Execute 1: Will acquire semaphore and run (and likely be cancelled) + limiter.Execute(ctx, func(callCtx context.Context) error { + return fakeService(callCtx, 1, false, 100*time.Millisecond, &startedCount) + }) + + // Execute 2: Will wait for semaphore, but context will likely cancel before it acquires + limiter.Execute(ctx, func(callCtx context.Context) error { + return fakeService(callCtx, 2, false, 200*time.Millisecond, &startedCount) + }) + + err := limiter.Wait() + + // Assert that only the first call actually started its service function + // (the second one was cancelled before acquiring the semaphore slot). + assert.Equal(t, int32(1), startedCount, "Only the first call should have started") + + // The first call might be cancelled, leading to 1 error. + // The second call's goroutine will exit via `<-ctx.Done()` *before* acquiring the semaphore, + // and thus won't add an error to the `limiter.err` + assert.ErrorContains(t, err, "service 2 cancelled: context deadline exceeded") + }) +} + +func TestLimiter_NoExecutions(t *testing.T) { + limiter := NewLimiter(3) + err := limiter.Wait() + + assert.NoError(t, err) +} + +func fakeService(ctx context.Context, id int, simulateError bool, duration time.Duration, startedCounter *int32) error { + atomic.AddInt32(startedCounter, 1) + select { + case <-ctx.Done(): + return fmt.Errorf("service %d cancelled: %v", id, ctx.Err()) + case <-time.After(duration): + if simulateError { + return fmt.Errorf("error from service %d", id) + } + return nil + } +}