diff --git a/pkg/volume/flexvolume/probe.go b/pkg/volume/flexvolume/probe.go index 2c4c5d0f98c..4b0eca2e0d7 100644 --- a/pkg/volume/flexvolume/probe.go +++ b/pkg/volume/flexvolume/probe.go @@ -40,6 +40,7 @@ type flexVolumeProber struct { factory PluginFactory fs utilfs.Filesystem probeAllNeeded bool + probeAllOnce sync.Once eventsMap map[string]volume.ProbeOperation // the key is the driver directory path, the value is the corresponding operation } @@ -55,7 +56,7 @@ func GetDynamicPluginProber(pluginDir string, runner exec.Interface) volume.Dyna } func (prober *flexVolumeProber) Init() error { - prober.testAndSetProbeAllNeeded(true) + prober.probeAllNeeded = true prober.eventsMap = map[string]volume.ProbeOperation{} if err := prober.createPluginDir(); err != nil { @@ -68,14 +69,18 @@ func (prober *flexVolumeProber) Init() error { return nil } -// If probeAllNeeded is true, probe all pluginDir +// If we haven't yet done so, probe all pluginDir // else probe events in eventsMap func (prober *flexVolumeProber) Probe() (events []volume.ProbeEvent, err error) { - if prober.probeAllNeeded { - prober.testAndSetProbeAllNeeded(false) - return prober.probeAll() + probedAll := false + prober.probeAllOnce.Do(func() { + events, err = prober.probeAll() + probedAll = true + prober.probeAllNeeded = false + }) + if probedAll { + return events, err } - return prober.probeMap() } @@ -278,10 +283,3 @@ func (prober *flexVolumeProber) createPluginDir() error { return nil } - -func (prober *flexVolumeProber) testAndSetProbeAllNeeded(newval bool) (oldval bool) { - prober.mutex.Lock() - defer prober.mutex.Unlock() - oldval, prober.probeAllNeeded = prober.probeAllNeeded, newval - return -} diff --git a/pkg/volume/flexvolume/probe_test.go b/pkg/volume/flexvolume/probe_test.go index 7ff57e9075c..6d591392b69 100644 --- a/pkg/volume/flexvolume/probe_test.go +++ b/pkg/volume/flexvolume/probe_test.go @@ -21,6 +21,8 @@ import ( "path/filepath" goruntime "runtime" "strings" + "sync" + "sync/atomic" "testing" "github.com/fsnotify/fsnotify" @@ -327,6 +329,47 @@ func TestProberSuccessAndError(t *testing.T) { assert.Error(t, err) } +// TestProberMultiThreaded tests the code path of many callers calling FindPluginBySpec/FindPluginByName +// which then calls refreshProbedPlugins which then calls prober.Probe() and ensures that the prober is thread safe +func TestProberMultiThreaded(t *testing.T) { + // Arrange + _, _, _, prober := initTestEnvironment(t) + totalEvents := atomic.Int32{} + totalErrors := atomic.Int32{} + pluginNameMutex := sync.RWMutex{} + var pluginName string + var wg sync.WaitGroup + + // Act + for i := 0; i < 100; i++ { + go func() { + defer wg.Done() + events, err := prober.Probe() + for _, event := range events { + if event.Op == volume.ProbeAddOrUpdate { + pluginNameMutex.Lock() + pluginName = event.Plugin.GetPluginName() + pluginNameMutex.Unlock() + } + } + // this fails if ProbeAll is not complete before the next call comes in but we have assumed that it has + pluginNameMutex.RLock() + assert.Equal(t, "fake-driver", pluginName) + pluginNameMutex.RUnlock() + totalEvents.Add(int32(len(events))) + if err != nil { + totalErrors.Add(1) + } + }() + wg.Add(1) + } + wg.Wait() + + // Assert + assert.Equal(t, int32(1), totalEvents.Load()) + assert.Equal(t, int32(0), totalErrors.Load()) +} + // Installs a mock driver (an empty file) in the mock fs. func installDriver(driverName string, fs utilfs.Filesystem) { driverPath := filepath.Join(pluginDir, driverName) diff --git a/pkg/volume/plugins.go b/pkg/volume/plugins.go index ee9a692ed22..0130f91f475 100644 --- a/pkg/volume/plugins.go +++ b/pkg/volume/plugins.go @@ -627,6 +627,7 @@ func (pm *VolumePluginMgr) initProbedPlugin(probedPlugin VolumePlugin) error { // specification. If no plugins can support or more than one plugin can // support it, return error. func (pm *VolumePluginMgr) FindPluginBySpec(spec *Spec) (VolumePlugin, error) { + pm.refreshProbedPlugins() pm.mutex.RLock() defer pm.mutex.RUnlock() @@ -643,7 +644,6 @@ func (pm *VolumePluginMgr) FindPluginBySpec(spec *Spec) (VolumePlugin, error) { } } - pm.refreshProbedPlugins() for _, plugin := range pm.probedPlugins { if plugin.CanSupport(spec) { match = plugin @@ -663,6 +663,7 @@ func (pm *VolumePluginMgr) FindPluginBySpec(spec *Spec) (VolumePlugin, error) { // FindPluginByName fetches a plugin by name. If no plugin is found, returns error. func (pm *VolumePluginMgr) FindPluginByName(name string) (VolumePlugin, error) { + pm.refreshProbedPlugins() pm.mutex.RLock() defer pm.mutex.RUnlock() @@ -671,7 +672,6 @@ func (pm *VolumePluginMgr) FindPluginByName(name string) (VolumePlugin, error) { match = v } - pm.refreshProbedPlugins() if plugin, found := pm.probedPlugins[name]; found { if match != nil { return nil, fmt.Errorf("multiple volume plugins matched: %s and %s", match.GetPluginName(), plugin.GetPluginName()) @@ -694,6 +694,12 @@ func (pm *VolumePluginMgr) refreshProbedPlugins() { klog.ErrorS(err, "Error dynamically probing plugins") } + if len(events) == 0 { + return + } + + pm.mutex.Lock() + defer pm.mutex.Unlock() // because the probe function can return a list of valid plugins // even when an error is present we still must add the plugins // or they will be skipped because each event only fires once diff --git a/pkg/volume/plugins_test.go b/pkg/volume/plugins_test.go index fa638ff138a..b3d03737a20 100644 --- a/pkg/volume/plugins_test.go +++ b/pkg/volume/plugins_test.go @@ -17,6 +17,10 @@ limitations under the License. package volume import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "sync" + "sync/atomic" "testing" v1 "k8s.io/api/core/v1" @@ -165,3 +169,62 @@ func Test_ValidatePodTemplate(t *testing.T) { t.Errorf("isPodTemplateValid(%v) returned (%v), want (%v)", pod.String(), got, "Error: pod specification does not contain any volume(s).") } } + +// TestVolumePluginMultiThreaded tests FindPluginByName/FindPluginBySpec in a multi-threaded environment. +// If these are called by different threads at the same time, they should still be able to reconcile the plugins +// and return the same results (no missing plugin) +func TestVolumePluginMultiThreaded(t *testing.T) { + vpm := VolumePluginMgr{} + var prober DynamicPluginProber = &fakeProber{events: []ProbeEvent{{PluginName: testPluginName, Op: ProbeAddOrUpdate, Plugin: &testPlugins{}}}} + err := vpm.InitPlugins([]VolumePlugin{}, prober, nil) + require.NoError(t, err) + + volumeSpec := &Spec{} + totalErrors := atomic.Int32{} + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + go func() { + defer wg.Done() + _, err := vpm.FindPluginByName(testPluginName) + if err != nil { + totalErrors.Add(1) + } + }() + wg.Add(1) + } + wg.Wait() + + assert.Equal(t, int32(0), totalErrors.Load()) + + for i := 0; i < 100; i++ { + go func() { + defer wg.Done() + _, err := vpm.FindPluginBySpec(volumeSpec) + if err != nil { + totalErrors.Add(1) + } + }() + wg.Add(1) + } + wg.Wait() + + assert.Equal(t, int32(0), totalErrors.Load()) +} + +type fakeProber struct { + events []ProbeEvent + firstExecution atomic.Bool +} + +func (prober *fakeProber) Init() error { + prober.firstExecution.Store(true) + return nil +} + +func (prober *fakeProber) Probe() (events []ProbeEvent, err error) { + if prober.firstExecution.CompareAndSwap(true, false) { + return prober.events, nil + } + return []ProbeEvent{}, nil +}