diff --git a/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher.go b/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher.go
index 954a787ac42..96359d96d3b 100644
--- a/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher.go
+++ b/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher.go
@@ -47,7 +47,7 @@ import (
 )
 
 var (
-	emptyFunc = func() {}
+	emptyFunc = func(bool) {}
 )
 
 const (
@@ -147,6 +147,10 @@ func (i *indexedWatchers) deleteWatcher(number int, value string, supported bool
 }
 
 func (i *indexedWatchers) terminateAll(objectType reflect.Type, done func(*cacheWatcher)) {
+	// note that we don't have to call the setDrainInputBufferLocked method on the watchers
+	// because we take advantage of the default value - stop immediately
+	// also, watchers that have already had their draining strategy set
+	// are no longer available (they were removed from the allWatchers and the valueWatchers maps)
 	if len(i.allWatchers) > 0 || len(i.valueWatchers) > 0 {
 		klog.Warningf("Terminating all watchers from cacher %v", objectType)
 	}
@@ -183,6 +187,10 @@ func newTimeBucketWatchers(clock clock.Clock, bookmarkFrequency time.Duration) *
 // adds a watcher to the bucket, if the deadline is before the start, it will be
 // added to the first one.
 func (t *watcherBookmarkTimeBuckets) addWatcher(w *cacheWatcher) bool {
+	// note that the returned time can be before t.createTime,
+	// especially in cases when the nextBookmarkTime method
+	// gives us the zero value of type Time,
+	// so bucketID can hold a negative value
	nextTime, ok := w.nextBookmarkTime(t.clock.Now(), t.bookmarkFrequency)
 	if !ok {
 		return false
@@ -517,7 +525,7 @@ func (c *Cacher) Watch(ctx context.Context, key string, opts storage.ListOptions
 	c.Lock()
 	defer c.Unlock()
 	// Update watcher.forget function once we can compute it.
-	watcher.forget = forgetWatcher(c, c.watcherIdx, triggerValue, triggerSupported)
+	watcher.forget = forgetWatcher(c, watcher, c.watcherIdx, triggerValue, triggerSupported)
 	c.watchers.addWatcher(watcher, c.watcherIdx, triggerValue, triggerSupported)
 
 	// Add it to the queue only when the client support watch bookmarks.
@@ -1028,11 +1036,13 @@ func (c *Cacher) Stop() {
 	c.stopWg.Wait()
 }
 
-func forgetWatcher(c *Cacher, index int, triggerValue string, triggerSupported bool) func() {
-	return func() {
+func forgetWatcher(c *Cacher, w *cacheWatcher, index int, triggerValue string, triggerSupported bool) func(bool) {
+	return func(drainWatcher bool) {
 		c.Lock()
 		defer c.Unlock()
 
+		w.setDrainInputBufferLocked(drainWatcher)
+
 		// It's possible that the watcher is already not in the structure (e.g. in case of
 		// simultaneous Stop() and terminateAllWatchers(), but it is safe to call stopLocked()
 		// on a watcher multiple times.
@@ -1156,7 +1166,7 @@ type cacheWatcher struct {
 	done      chan struct{}
 	filter    filterWithAttrsFunc
 	stopped   bool
-	forget    func()
+	forget    func(bool)
 	versioner storage.Versioner
 	// The watcher will be closed by server after the deadline,
 	// save it here to send bookmark events before that.
@@ -1168,9 +1178,13 @@ type cacheWatcher struct {
 	// human readable identifier that helps assigning cacheWatcher
 	// instance with request
 	identifier string
+
+	// drainInputBuffer indicates whether we should delay closing this watcher
+	// and send all events in the input buffer.
+	drainInputBuffer bool
 }
 
-func newCacheWatcher(chanSize int, filter filterWithAttrsFunc, forget func(), versioner storage.Versioner, deadline time.Time, allowWatchBookmarks bool, objectType reflect.Type, identifier string) *cacheWatcher {
+func newCacheWatcher(chanSize int, filter filterWithAttrsFunc, forget func(bool), versioner storage.Versioner, deadline time.Time, allowWatchBookmarks bool, objectType reflect.Type, identifier string) *cacheWatcher {
 	return &cacheWatcher{
 		input:               make(chan *watchCacheEvent, chanSize),
 		result:              make(chan watch.Event, chanSize),
@@ -1193,16 +1207,29 @@ func (c *cacheWatcher) ResultChan() <-chan watch.Event {
 
 // Implements watch.Interface.
 func (c *cacheWatcher) Stop() {
-	c.forget()
+	c.forget(false)
 }
 
 // we rely on the fact that stopLocked is actually protected by Cacher.Lock()
 func (c *cacheWatcher) stopLocked() {
 	if !c.stopped {
 		c.stopped = true
-		close(c.done)
+		// stopping without draining the input buffer was requested.
+		if !c.drainInputBuffer {
+			close(c.done)
+		}
 		close(c.input)
 	}
+
+	// Even if the watcher was already stopped, if it previously was
+	// using draining mode and it's not using it now, we need to
+	// close the done channel now. Otherwise we could leak the
+	// processing goroutine: if it tries to put more objects into
+	// the result channel, the channel will be full and there will
+	// already be no one processing events on the receiving end.
+	if !c.drainInputBuffer && !c.isDoneChannelClosedLocked() {
+		close(c.done)
+	}
 }
 
 func (c *cacheWatcher) nonblockingAdd(event *watchCacheEvent) bool {
@@ -1227,7 +1254,7 @@ func (c *cacheWatcher) add(event *watchCacheEvent, timer *time.Timer) bool {
 		// we simply terminate it.
 		klog.V(1).Infof("Forcing %v watcher close due to unresponsiveness: %v. len(c.input) = %v, len(c.result) = %v", c.objectType.String(), c.identifier, len(c.input), len(c.result))
 		metrics.TerminatedWatchersCounter.WithLabelValues(c.objectType.String()).Inc()
-		c.forget()
+		c.forget(false)
 	}
 
 	if timer == nil {
@@ -1273,6 +1300,22 @@ func (c *cacheWatcher) nextBookmarkTime(now time.Time, bookmarkFrequency time.Du
 	return heartbeatTime, true
 }
 
+// setDrainInputBufferLocked, when called with true, indicates that we should
+// delay closing this watcher until all events residing in the input buffer have been sent.
+func (c *cacheWatcher) setDrainInputBufferLocked(drain bool) {
+	c.drainInputBuffer = drain
+}
+
+// isDoneChannelClosedLocked checks if the c.done channel is closed
+func (c *cacheWatcher) isDoneChannelClosedLocked() bool {
+	select {
+	case <-c.done:
+		return true
+	default:
+	}
+	return false
+}
+
 func getMutableObject(object runtime.Object) runtime.Object {
 	if _, ok := object.(*cachingObject); ok {
 		// It is safe to return without deep-copy, because the underlying
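For readers following the drain semantics that stopLocked introduces above, here is a minimal, self-contained Go sketch of just the two-channel shutdown protocol. It is not the actual cacheWatcher; the miniWatcher type and its fields are hypothetical stand-ins for input, result, done, and drainInputBuffer:

package main

import (
	"fmt"
	"sync"
)

// miniWatcher models only the shutdown protocol: input feeds a processing
// goroutine that forwards events to result; closing done aborts delivery.
type miniWatcher struct {
	sync.Mutex
	input   chan int
	result  chan int
	done    chan struct{}
	stopped bool
	drain   bool
}

// stopLocked mirrors the patched cacheWatcher.stopLocked: with drain set,
// only input is closed so the processing goroutine can flush buffered
// events; done stays open. A later stop with drain unset closes done to
// unblock a sender stuck on a full result channel, preventing a leak.
func (w *miniWatcher) stopLocked() {
	if !w.stopped {
		w.stopped = true
		if !w.drain {
			close(w.done)
		}
		close(w.input)
	}
	if !w.drain && !w.isDoneClosed() {
		close(w.done)
	}
}

func (w *miniWatcher) isDoneClosed() bool {
	select {
	case <-w.done:
		return true
	default:
		return false
	}
}

// process stands in for cacheWatcher.process: it drains input until the
// channel is closed, bailing out early only if done is closed.
func (w *miniWatcher) process() {
	defer close(w.result)
	for ev := range w.input {
		select {
		case w.result <- ev:
		case <-w.done:
			return
		}
	}
}

func main() {
	w := &miniWatcher{
		input:  make(chan int, 2),
		result: make(chan int, 2),
		done:   make(chan struct{}),
	}
	w.input <- 1
	w.input <- 2
	go w.process()

	w.Lock()
	w.drain = true // forget(true): drain before terminating
	w.stopLocked()
	w.Unlock()

	for ev := range w.result {
		fmt.Println("drained event:", ev) // both buffered events arrive
	}
}

The test file below exercises both paths: a drained watcher whose client reads everything, and a drain request where the client disconnects first.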
diff --git a/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher_whitebox_test.go b/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher_whitebox_test.go
index 1616fe6b1f8..7d57855a213 100644
--- a/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher_whitebox_test.go
+++ b/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher_whitebox_test.go
@@ -58,13 +58,14 @@ func TestCacheWatcherCleanupNotBlockedByResult(t *testing.T) {
 	var w *cacheWatcher
 	count := 0
 	filter := func(string, labels.Set, fields.Set) bool { return true }
-	forget := func() {
+	forget := func(drainWatcher bool) {
 		lock.Lock()
 		defer lock.Unlock()
 		count++
 		// forget() has to stop the watcher, as only stopping the watcher
 		// triggers stopping the process() goroutine which we are in the
 		// end waiting for in this test.
+		w.setDrainInputBufferLocked(drainWatcher)
 		w.stopLocked()
 	}
 	initEvents := []*watchCacheEvent{
@@ -89,7 +90,7 @@ func TestCacheWatcherHandlesFiltering(t *testing.T) {
 	filter := func(_ string, _ labels.Set, field fields.Set) bool {
 		return field["spec.nodeName"] == "host"
 	}
-	forget := func() {}
+	forget := func(bool) {}
 
 	testCases := []struct {
 		events   []*watchCacheEvent
@@ -210,6 +211,7 @@ TestCase:
 				break TestCase
 			default:
 			}
+			w.setDrainInputBufferLocked(false)
 			w.stopLocked()
 		}
 	}
@@ -524,7 +526,8 @@ func TestCacheWatcherStoppedInAnotherGoroutine(t *testing.T) {
 	var w *cacheWatcher
 	done := make(chan struct{})
 	filter := func(string, labels.Set, fields.Set) bool { return true }
-	forget := func() {
+	forget := func(drainWatcher bool) {
+		w.setDrainInputBufferLocked(drainWatcher)
 		w.stopLocked()
 		done <- struct{}{}
 	}
@@ -556,6 +559,7 @@ func TestCacheWatcherStoppedInAnotherGoroutine(t *testing.T) {
 		case <-time.After(time.Second):
 			t.Fatal("expected received a event on ResultChan")
 		}
+		w.setDrainInputBufferLocked(false)
 		w.stopLocked()
 	}
 }
@@ -667,7 +671,7 @@ func TestTimeBucketWatchersBasic(t *testing.T) {
 	filter := func(_ string, _ labels.Set, _ fields.Set) bool {
 		return true
 	}
-	forget := func() {}
+	forget := func(bool) {}
 
 	newWatcher := func(deadline time.Time) *cacheWatcher {
 		return newCacheWatcher(0, filter, forget, testVersioner{}, deadline, true, objectType, "")
@@ -1581,3 +1585,72 @@ func TestCacheIntervalInvalidationStopsWatch(t *testing.T) {
 		t.Errorf("unexpected number of events received, expected: %d, got: %d", bufferSize+1, received)
 	}
 }
+
+// TestCacheWatcherDraining verifies that the cacheWatcher.process goroutine is properly cleaned up when draining was requested
+func TestCacheWatcherDraining(t *testing.T) {
+	var lock sync.RWMutex
+	var w *cacheWatcher
+	count := 0
+	filter := func(string, labels.Set, fields.Set) bool { return true }
+	forget := func(drainWatcher bool) {
+		lock.Lock()
+		defer lock.Unlock()
+		count++
+		w.setDrainInputBufferLocked(drainWatcher)
+		w.stopLocked()
+	}
+	initEvents := []*watchCacheEvent{
+		{Object: &v1.Pod{}},
+		{Object: &v1.Pod{}},
+	}
+	w = newCacheWatcher(1, filter, forget, testVersioner{}, time.Now(), true, objectType, "")
+	go w.processInterval(context.Background(), intervalFromEvents(initEvents), 0)
+	if !w.add(&watchCacheEvent{Object: &v1.Pod{}}, time.NewTimer(1*time.Second)) {
+		t.Fatal("failed adding an event to the watcher")
+	}
+	forget(true) // drain the watcher
+	<-w.ResultChan()
+	<-w.ResultChan()
+	<-w.ResultChan()
+	if err := wait.PollImmediate(1*time.Second, 5*time.Second, func() (bool, error) {
+		lock.RLock()
+		defer lock.RUnlock()
+		return count == 2, nil
+	}); err != nil {
+		t.Fatalf("expected forget() to be called twice, because processInterval should call Stop(): %v", err)
+	}
+}
+
+// TestCacheWatcherDrainingRequestedButNotDrained verifies that the cacheWatcher.process goroutine is properly cleaned up when draining was requested
+// but the client never actually gets any data
+func TestCacheWatcherDrainingRequestedButNotDrained(t *testing.T) {
+	var lock sync.RWMutex
+	var w *cacheWatcher
+	count := 0
+	filter := func(string, labels.Set, fields.Set) bool { return true }
+	forget := func(drainWatcher bool) {
+		lock.Lock()
+		defer lock.Unlock()
+		count++
+		w.setDrainInputBufferLocked(drainWatcher)
+		w.stopLocked()
+	}
+	initEvents := []*watchCacheEvent{
+		{Object: &v1.Pod{}},
+		{Object: &v1.Pod{}},
+	}
+	w = newCacheWatcher(1, filter, forget, testVersioner{}, time.Now(), true, objectType, "")
+	go w.processInterval(context.Background(), intervalFromEvents(initEvents), 0)
+	if !w.add(&watchCacheEvent{Object: &v1.Pod{}}, time.NewTimer(1*time.Second)) {
+		t.Fatal("failed adding an event to the watcher")
+	}
+	forget(true) // drain the watcher
+	w.Stop()     // client disconnected, timeout expired, or ctx was actually closed
+	if err := wait.PollImmediate(1*time.Second, 5*time.Second, func() (bool, error) {
+		lock.RLock()
+		defer lock.RUnlock()
+		return count == 3, nil
+	}); err != nil {
+		t.Fatalf("expected forget() to be called three times, because processInterval should call Stop(): %v", err)
+	}
+}
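The count assertions in the two tests above rely on processInterval calling Stop() (and therefore forget) via defer when its goroutine exits, so forget always runs once more than the explicit calls the test makes. A minimal sketch of that pattern, with hypothetical names (the real processInterval takes a context and an event interval, not a slice):

package main

import "fmt"

// watcher models only the forget/Stop interplay exercised by the tests:
// the processing goroutine defers Stop, so forget runs one extra time.
type watcher struct {
	forget func(drain bool)
}

func (w *watcher) Stop() { w.forget(false) }

// processInterval stands in for cacheWatcher.processInterval: deliver
// events, then unregister via the deferred Stop on the way out.
func (w *watcher) processInterval(events []int, result chan<- int) {
	defer w.Stop()
	defer close(result)
	for _, ev := range events {
		result <- ev
	}
}

func main() {
	count := 0
	w := &watcher{}
	w.forget = func(bool) { count++ }

	result := make(chan int, 3)
	w.processInterval([]int{1, 2, 3}, result) // deferred Stop fires on return
	for range result {
	}

	w.forget(true) // the explicit drain request issued by the test
	fmt.Println("forget called", count, "times") // 2, matching count == 2
}

In TestCacheWatcherDraining this yields count == 2 (one explicit forget(true) plus the deferred Stop); TestCacheWatcherDrainingRequestedButNotDrained adds an explicit w.Stop(), hence count == 3.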