diff --git a/pkg/scheduler/core/equivalence_cache.go b/pkg/scheduler/core/equivalence_cache.go index 1eb7df685d7..1b6183ba99f 100644 --- a/pkg/scheduler/core/equivalence_cache.go +++ b/pkg/scheduler/core/equivalence_cache.go @@ -38,7 +38,7 @@ const maxCacheEntries = 100 // 1. a map of AlgorithmCache with node name as key // 2. function to get equivalence pod type EquivalenceCache struct { - sync.RWMutex + mu sync.Mutex algorithmCache map[string]AlgorithmCache } @@ -84,9 +84,9 @@ func (ec *EquivalenceCache) RunPredicate( equivClassInfo *equivalenceClassInfo, cache schedulercache.Cache, ) (bool, []algorithm.PredicateFailureReason, error) { - ec.Lock() - defer ec.Unlock() - fit, reasons, invalid := ec.lookupResult(pod.GetName(), nodeInfo.Node().GetName(), predicateKey, equivClassInfo.hash, false) + ec.mu.Lock() + defer ec.mu.Unlock() + fit, reasons, invalid := ec.lookupResult(pod.GetName(), nodeInfo.Node().GetName(), predicateKey, equivClassInfo.hash) if !invalid { return fit, reasons, nil } @@ -96,7 +96,7 @@ func (ec *EquivalenceCache) RunPredicate( } // Skip update if NodeInfo is stale. if cache != nil && cache.IsUpToDate(nodeInfo) { - ec.updateResult(pod.GetName(), nodeInfo.Node().GetName(), predicateKey, fit, reasons, equivClassInfo.hash, false) + ec.updateResult(pod.GetName(), nodeInfo.Node().GetName(), predicateKey, fit, reasons, equivClassInfo.hash) } return fit, reasons, nil } @@ -107,12 +107,7 @@ func (ec *EquivalenceCache) updateResult( fit bool, reasons []algorithm.PredicateFailureReason, equivalenceHash uint64, - needLock bool, ) { - if needLock { - ec.Lock() - defer ec.Unlock() - } if _, exist := ec.algorithmCache[nodeName]; !exist { ec.algorithmCache[nodeName] = newAlgorithmCache() } @@ -140,12 +135,8 @@ func (ec *EquivalenceCache) updateResult( // 3. if cache item is not found func (ec *EquivalenceCache) lookupResult( podName, nodeName, predicateKey string, - equivalenceHash uint64, needLock bool, + equivalenceHash uint64, ) (bool, []algorithm.PredicateFailureReason, bool) { - if needLock { - ec.RLock() - defer ec.RUnlock() - } glog.V(5).Infof("Begin to calculate predicate: %v for pod: %s on node: %s based on equivalence cache", predicateKey, podName, nodeName) if algorithmCache, exist := ec.algorithmCache[nodeName]; exist { @@ -170,8 +161,8 @@ func (ec *EquivalenceCache) InvalidateCachedPredicateItem(nodeName string, predi if len(predicateKeys) == 0 { return } - ec.Lock() - defer ec.Unlock() + ec.mu.Lock() + defer ec.mu.Unlock() if algorithmCache, exist := ec.algorithmCache[nodeName]; exist { for predicateKey := range predicateKeys { algorithmCache.predicatesCache.Remove(predicateKey) @@ -185,8 +176,8 @@ func (ec *EquivalenceCache) InvalidateCachedPredicateItemOfAllNodes(predicateKey if len(predicateKeys) == 0 { return } - ec.Lock() - defer ec.Unlock() + ec.mu.Lock() + defer ec.mu.Unlock() // algorithmCache uses nodeName as key, so we just iterate it and invalid given predicates for _, algorithmCache := range ec.algorithmCache { for predicateKey := range predicateKeys { @@ -199,8 +190,8 @@ func (ec *EquivalenceCache) InvalidateCachedPredicateItemOfAllNodes(predicateKey // InvalidateAllCachedPredicateItemOfNode marks all cached items on given node as invalid func (ec *EquivalenceCache) InvalidateAllCachedPredicateItemOfNode(nodeName string) { - ec.Lock() - defer ec.Unlock() + ec.mu.Lock() + defer ec.mu.Unlock() delete(ec.algorithmCache, nodeName) glog.V(5).Infof("Done invalidating all cached predicates on node: %s", nodeName) } diff --git a/pkg/scheduler/core/equivalence_cache_test.go b/pkg/scheduler/core/equivalence_cache_test.go index 512fbf2afb0..0129fd2db2a 100644 --- a/pkg/scheduler/core/equivalence_cache_test.go +++ b/pkg/scheduler/core/equivalence_cache_test.go @@ -253,7 +253,9 @@ func TestRunPredicate(t *testing.T) { ecache := NewEquivalenceCache() equivClass := ecache.getEquivalenceClassInfo(pod) if test.expectCacheHit { + ecache.mu.Lock() ecache.updateResult(pod.Name, node.Node().Name, "testPredicate", test.expectFit, test.expectedReasons, equivClass.hash) + ecache.mu.Unlock() } fit, reasons, err := ecache.RunPredicate(test.pred.predicate, "testPredicate", pod, meta, node, equivClass, test.cache) @@ -287,7 +289,9 @@ func TestRunPredicate(t *testing.T) { if !test.expectCacheHit && test.pred.callCount == 0 { t.Errorf("Predicate should be called") } + ecache.mu.Lock() _, _, invalid := ecache.lookupResult(pod.Name, node.Node().Name, "testPredicate", equivClass.hash) + ecache.mu.Unlock() if invalid && test.expectCacheWrite { t.Errorf("Cache write should happen") } @@ -350,6 +354,7 @@ func TestUpdateResult(t *testing.T) { test.equivalenceHash: predicateItem, }) } + ecache.mu.Lock() ecache.updateResult( test.pod, test.nodeName, @@ -357,8 +362,8 @@ func TestUpdateResult(t *testing.T) { test.fit, test.reasons, test.equivalenceHash, - true, ) + ecache.mu.Unlock() value, ok := ecache.algorithmCache[test.nodeName].predicatesCache.Get(test.predicateKey) if !ok { @@ -460,6 +465,7 @@ func TestLookupResult(t *testing.T) { for _, test := range tests { ecache := NewEquivalenceCache() // set cached item to equivalence cache + ecache.mu.Lock() ecache.updateResult( test.podName, test.nodeName, @@ -467,8 +473,8 @@ func TestLookupResult(t *testing.T) { test.cachedItem.fit, test.cachedItem.reasons, test.equivalenceHashForUpdatePredicate, - true, ) + ecache.mu.Unlock() // if we want to do invalid, invalid the cached item if test.expectedInvalidPredicateKey { predicateKeys := sets.NewString() @@ -476,12 +482,13 @@ func TestLookupResult(t *testing.T) { ecache.InvalidateCachedPredicateItem(test.nodeName, predicateKeys) } // calculate predicate with equivalence cache + ecache.mu.Lock() fit, reasons, invalid := ecache.lookupResult(test.podName, test.nodeName, test.predicateKey, test.equivalenceHashForCalPredicate, - true, ) + ecache.mu.Unlock() // returned invalid should match expectedInvalidPredicateKey or expectedInvalidEquivalenceHash if test.equivalenceHashForUpdatePredicate != test.equivalenceHashForCalPredicate { if invalid != test.expectedInvalidEquivalenceHash { @@ -668,6 +675,7 @@ func TestInvalidateCachedPredicateItemOfAllNodes(t *testing.T) { for _, test := range tests { // set cached item to equivalence cache + ecache.mu.Lock() ecache.updateResult( test.podName, test.nodeName, @@ -675,8 +683,8 @@ func TestInvalidateCachedPredicateItemOfAllNodes(t *testing.T) { test.cachedItem.fit, test.cachedItem.reasons, test.equivalenceHashForUpdatePredicate, - true, ) + ecache.mu.Unlock() } // invalidate cached predicate for all nodes @@ -735,6 +743,7 @@ func TestInvalidateAllCachedPredicateItemOfNode(t *testing.T) { for _, test := range tests { // set cached item to equivalence cache + ecache.mu.Lock() ecache.updateResult( test.podName, test.nodeName, @@ -742,8 +751,8 @@ func TestInvalidateAllCachedPredicateItemOfNode(t *testing.T) { test.cachedItem.fit, test.cachedItem.reasons, test.equivalenceHashForUpdatePredicate, - true, ) + ecache.mu.Unlock() } for _, test := range tests {