From 779f157ecfb24d0ee944f18e481bfa8cc8c94f6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Tyczy=C5=84ski?= Date: Tue, 22 Feb 2022 17:14:15 +0100 Subject: [PATCH] Fix potential race in dispatching watch event --- .../apiserver/pkg/storage/cacher/cacher.go | 20 +++++++++---------- .../storage/cacher/cacher_whitebox_test.go | 13 +++--------- .../pkg/storage/cacher/caching_object.go | 4 ++++ 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher.go b/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher.go index fd462dcbae3..84940ab0cd7 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher.go @@ -851,14 +851,14 @@ func (c *Cacher) dispatchEvent(event *watchCacheEvent) { // from it justifies increased memory usage, so for now we drop the cached // serializations after dispatching this event. // - // Given the deep-copies that are done to create cachingObjects, - // we try to cache serializations only if there are at least 3 watchers. - if len(c.watchersBuffer) >= 3 { - // Make a shallow copy to allow overwriting Object and PrevObject. - wcEvent := *event - setCachingObjects(&wcEvent, c.versioner) - event = &wcEvent - } + // Given that CachingObject is just wrapping the object and not perfoming + // deep-copying (until some field is explicitly being modified), we create + // it unconditionally to ensure safety and reduce deep-copying. + // + // Make a shallow copy to allow overwriting Object and PrevObject. + wcEvent := *event + setCachingObjects(&wcEvent, c.versioner) + event = &wcEvent c.blockedWatchers = c.blockedWatchers[:0] for _, watcher := range c.watchersBuffer { @@ -1288,9 +1288,9 @@ func (c *cacheWatcher) convertToWatchEvent(event *watchCacheEvent) *watch.Event switch { case curObjPasses && !oldObjPasses: - return &watch.Event{Type: watch.Added, Object: event.Object} + return &watch.Event{Type: watch.Added, Object: getMutableObject(event.Object)} case curObjPasses && oldObjPasses: - return &watch.Event{Type: watch.Modified, Object: event.Object} + return &watch.Event{Type: watch.Modified, Object: getMutableObject(event.Object)} case !curObjPasses && oldObjPasses: // return a delete event with the previous object content, but with the event's resource version oldObj := getMutableObject(event.PrevObject) diff --git a/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher_whitebox_test.go b/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher_whitebox_test.go index 03ea945691e..b71260102e2 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher_whitebox_test.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/cacher/cacher_whitebox_test.go @@ -1370,17 +1370,10 @@ func testCachingObjects(t *testing.T, watchersCount int) { } var object runtime.Object - if watchersCount >= 3 { - if _, ok := event.Object.(runtime.CacheableObject); !ok { - t.Fatalf("Object in %s event should support caching: %#v", event.Type, event.Object) - } - object = event.Object.(runtime.CacheableObject).GetObject() - } else { - if _, ok := event.Object.(runtime.CacheableObject); ok { - t.Fatalf("Object in %s event should not support caching: %#v", event.Type, event.Object) - } - object = event.Object.DeepCopyObject() + if _, ok := event.Object.(runtime.CacheableObject); !ok { + t.Fatalf("Object in %s event should support caching: %#v", event.Type, event.Object) } + object = event.Object.(runtime.CacheableObject).GetObject() if event.Type == watch.Deleted { resourceVersion, err := cacher.versioner.ObjectResourceVersion(cacher.watchCache.cache[index].PrevObject) diff --git a/staging/src/k8s.io/apiserver/pkg/storage/cacher/caching_object.go b/staging/src/k8s.io/apiserver/pkg/storage/cacher/caching_object.go index 91a22cb459f..9ee5c951f11 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/cacher/caching_object.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/cacher/caching_object.go @@ -137,6 +137,10 @@ func (o *cachingObject) CacheEncode(id runtime.Identifier, encode func(runtime.O result := o.getSerializationResult(id) result.once.Do(func() { buffer := bytes.NewBuffer(nil) + // TODO(wojtek-t): This is currently making a copy to avoid races + // in cases where encoding is making subtle object modifications, + // e.g. #82497 + // Figure out if we can somehow avoid this under some conditions. result.err = encode(o.GetObject(), buffer) result.raw = buffer.Bytes() })