diff --git a/pkg/util/time_cache.go b/pkg/util/time_cache.go index 7a9bf4a9d0b..bcb06231bc2 100644 --- a/pkg/util/time_cache.go +++ b/pkg/util/time_cache.go @@ -40,9 +40,13 @@ type timeCacheEntry struct { type timeCache struct { clock Clock fillFunc func(string) T - cache map[string]timeCacheEntry - lock sync.Mutex ttl time.Duration + + inFlight map[string]chan T + inFlightLock sync.Mutex + + cache map[string]timeCacheEntry + lock sync.RWMutex } // NewTimeCache returns a cache which calls fill to fill its entries, and @@ -51,6 +55,7 @@ func NewTimeCache(clock Clock, ttl time.Duration, fill func(key string) T) TimeC return &timeCache{ clock: clock, fillFunc: fill, + inFlight: map[string]chan T{}, cache: map[string]timeCacheEntry{}, ttl: ttl, } @@ -59,17 +64,62 @@ func NewTimeCache(clock Clock, ttl time.Duration, fill func(key string) T) TimeC // Get returns the value of key from the cache, if it is present // and recent enough; otherwise, it blocks while it gets the value. func (c *timeCache) Get(key string) T { - c.lock.Lock() - defer c.lock.Unlock() + if item, ok := c.get(key); ok { + return item + } + + // We need to fill the cache. Calling the function could be + // expensive, so do it while unlocked. + wait := c.fillOrWait(key) + item := <-wait + + // Put it back in the channel in case there's multiple waiters + // (this channel is non-blocking) + wait <- item + return item +} + +// returns the item and true if it is found and not expired, otherwise nil and false. +func (c *timeCache) get(key string) (T, bool) { + c.lock.RLock() + defer c.lock.RUnlock() data, ok := c.cache[key] now := c.clock.Now() - if !ok || now.Sub(data.lastUpdate) > c.ttl { - data = timeCacheEntry{ - item: c.fillFunc(key), - lastUpdate: now, - } - c.cache[key] = data + return nil, false } - return data.item + return data.item, true +} + +func (c *timeCache) fillOrWait(key string) chan T { + c.inFlightLock.Lock() + defer c.inFlightLock.Unlock() + + // Already a call in progress? + if current, ok := c.inFlight[key]; ok { + return current + } + + // We are the first, so we have to make the call. + result := make(chan T, 1) // non-blocking + c.inFlight[key] = result + go func() { + // Make potentially slow call + data := timeCacheEntry{ + item: c.fillFunc(key), + lastUpdate: c.clock.Now(), + } + result <- data.item + + // Store in cache + c.lock.Lock() + c.cache[key] = data + c.lock.Unlock() + + // Remove in flight entry + c.inFlightLock.Lock() + delete(c.inFlight, key) + c.inFlightLock.Unlock() + }() + return result } diff --git a/pkg/util/time_cache_test.go b/pkg/util/time_cache_test.go index 1f42b08dcdd..7e8f248d68f 100644 --- a/pkg/util/time_cache_test.go +++ b/pkg/util/time_cache_test.go @@ -17,8 +17,11 @@ limitations under the License. package util import ( + "sync" "testing" "time" + + fuzz "github.com/google/gofuzz" ) func TestCacheExpire(t *testing.T) { @@ -61,3 +64,54 @@ func TestCacheNotExpire(t *testing.T) { t.Errorf("Wrong number of calls for foo: wanted %v, got %v", e, a) } } + +func TestCacheParallel(t *testing.T) { + ff := func(key string) T { time.Sleep(time.Second); return key } + clock := &FakeClock{time.Now()} + c := NewTimeCache(clock, 60*time.Second, ff) + + // Make some keys + keys := []string{} + fuzz.New().NilChance(0).NumElements(50, 50).Fuzz(&keys) + + // If we have high parallelism, this will take only a second. + var wg sync.WaitGroup + wg.Add(len(keys)) + for _, key := range keys { + go func(key string) { + c.Get(key) + wg.Done() + }(key) + } + wg.Wait() +} + +func TestCacheParallelOneCall(t *testing.T) { + calls := 0 + var callLock sync.Mutex + ff := func(key string) T { + time.Sleep(time.Second) + callLock.Lock() + defer callLock.Unlock() + calls++ + return key + } + clock := &FakeClock{time.Now()} + c := NewTimeCache(clock, 60*time.Second, ff) + + // If we have high parallelism, this will take only a second. + var wg sync.WaitGroup + wg.Add(50) + for i := 0; i < 50; i++ { + go func(key string) { + c.Get(key) + wg.Done() + }("aoeu") + } + wg.Wait() + + // And if we wait for existing calls, we should have only one call. + if e, a := 1, calls; e != a { + t.Errorf("Expected %v, got %v", e, a) + } +}