diff --git a/pkg/util/time_cache.go b/pkg/util/time_cache.go
index bcb06231bc2..f939294e027 100644
--- a/pkg/util/time_cache.go
+++ b/pkg/util/time_cache.go
@@ -80,19 +80,25 @@ func (c *timeCache) Get(key string) T {
 }
 
 // returns the item and true if it is found and not expired, otherwise nil and false.
+// If this returns false, it has locked c.inFlightLock and it is caller's responsibility
+// to unlock that.
 func (c *timeCache) get(key string) (T, bool) {
 	c.lock.RLock()
 	defer c.lock.RUnlock()
 	data, ok := c.cache[key]
 	now := c.clock.Now()
 	if !ok || now.Sub(data.lastUpdate) > c.ttl {
+		// We must lock this while we hold c.lock-- otherwise, a writer could
+		// write to c.cache and remove the channel from c.inFlight before we
+		// manage to read c.inFlight.
+		c.inFlightLock.Lock()
 		return nil, false
 	}
 	return data.item, true
 }
 
+// c.inFlightLock MUST be locked before calling this. fillOrWait will unlock it.
 func (c *timeCache) fillOrWait(key string) chan T {
-	c.inFlightLock.Lock()
 	defer c.inFlightLock.Unlock()
 
 	// Already a call in progress?
@@ -104,7 +110,9 @@ func (c *timeCache) fillOrWait(key string) chan T {
 	result := make(chan T, 1) // non-blocking
 	c.inFlight[key] = result
 	go func() {
-		// Make potentially slow call
+		// Make potentially slow call.
+		// While this call is in flight, fillOrWait will
+		// presumably exit.
 		data := timeCacheEntry{
 			item:       c.fillFunc(key),
 			lastUpdate: c.clock.Now(),
@@ -113,13 +121,13 @@
 		// Store in cache
 		c.lock.Lock()
+		defer c.lock.Unlock()
 		c.cache[key] = data
-		c.lock.Unlock()
 
 		// Remove in flight entry
 		c.inFlightLock.Lock()
+		defer c.inFlightLock.Unlock()
 		delete(c.inFlight, key)
-		c.inFlightLock.Unlock()
 	}()
 	return result
 }
 
diff --git a/pkg/util/time_cache_test.go b/pkg/util/time_cache_test.go
index 7e8f248d68f..0016dcbd810 100644
--- a/pkg/util/time_cache_test.go
+++ b/pkg/util/time_cache_test.go
@@ -17,6 +17,8 @@ limitations under the License.
 package util
 
 import (
+	"math/rand"
+	"runtime"
 	"sync"
 	"testing"
 	"time"
@@ -115,3 +117,67 @@ func TestCacheParallelOneCall(t *testing.T) {
 		t.Errorf("Expected %v, got %v", e, a)
 	}
 }
+
+func TestCacheParallelNoDeadlocksNoDoubleCalls(t *testing.T) {
+	// Make 50 random keys
+	keys := []string{}
+	fuzz.New().NilChance(0).NumElements(50, 50).Fuzz(&keys)
+
+	// Data structure for tracking when each key is accessed.
+	type callTrack struct {
+		sync.Mutex
+		accessTimes []time.Time
+	}
+	calls := map[string]*callTrack{}
+	for _, k := range keys {
+		calls[k] = &callTrack{}
+	}
+
+	// This is called to fill the cache in the case of a cache miss
+	// or cache entry expiration. We record the time.
+	ff := func(key string) T {
+		ct := calls[key]
+		ct.Lock()
+		ct.accessTimes = append(ct.accessTimes, time.Now())
+		ct.Unlock()
+		// make sure that there is time for multiple requests to come in
+		// for the same key before this returns.
+		time.Sleep(time.Millisecond)
+		return key
+	}
+
+	cacheDur := 10 * time.Millisecond
+	c := NewTimeCache(RealClock{}, cacheDur, ff)
+
+	// Spawn a bunch of goroutines, each of which sequentially requests
+	// 500 random keys from the cache.
+	runtime.GOMAXPROCS(16)
+	var wg sync.WaitGroup
+	for i := 0; i < 500; i++ {
+		wg.Add(1)
+		go func(seed int64) {
+			r := rand.New(rand.NewSource(seed))
+			for i := 0; i < 500; i++ {
+				c.Get(keys[r.Intn(len(keys))])
+			}
+			wg.Done()
+		}(rand.Int63())
+	}
+	wg.Wait()
+
+	// Since the cache should hold things for 10ms, no calls for a given key
+	// should be more closely spaced than that.
+	for k, ct := range calls {
+		if len(ct.accessTimes) < 2 {
+			continue
+		}
+		cur := ct.accessTimes[0]
+		for i := 1; i < len(ct.accessTimes); i++ {
+			next := ct.accessTimes[i]
+			if next.Sub(cur) < cacheDur {
+				t.Errorf("%v was called at %v and %v", k, cur, next)
+			}
+			cur = next
+		}
+	}
+}