wait: Use a context implementation for ContextForChannel

ContextForChannel uses a goroutine to transform a channel close to
a context cancel. However, this exposes a synchronization issue if
we want to unify the underlying implementation between contextless
and with context - a ConditionFunc that closes the channel today
expects the behavior that no subsequent conditions will be invoked
(we have a test in wait_test.go TestUntilReturnsImmediately that
verifies this expectation). We can't unify the implementation
without ensuring this property holds.

To do that this commit changes from the goroutine propagation to
implementing context.Context and using stopCh as the Done(). We
then implement Err() by returning context.Canceled and stub the
other methods. Since our context cannot be explicitly cancelled
by users, we cease to return the cancelFn and callers that need
that behavior must wrap the context as normal.

This should be invisible to clients - they would already observe
the same behavior from the context, and the existing error
behavior of Poll* is preserved (which ignores ctx.Err()).

As a side effect, one less goroutine is created making it more
efficient.
This commit is contained in:
Clayton Coleman 2023-01-17 15:01:02 -05:00
parent 5e9fc39d17
commit 95051a63b3
No known key found for this signature in database
GPG Key ID: CF7DB7FC943D3E0E
5 changed files with 38 additions and 37 deletions

View File

@ -242,7 +242,7 @@ func Run(c *config.CompletedConfig, stopCh <-chan struct{}) error {
// No leader election, run directly
if !c.ComponentConfig.Generic.LeaderElection.LeaderElect {
ctx, _ := wait.ContextForChannel(stopCh)
ctx := wait.ContextForChannel(stopCh)
run(ctx, saTokenControllerInitFunc, NewControllerInitializers)
return nil
}

View File

@ -223,6 +223,33 @@ func (cf ConditionFunc) WithContext() ConditionWithContextFunc {
}
}
// ContextForChannel provides a context that will be treated as cancelled
// when the provided parentCh is closed. The implementation returns
// context.Canceled for Err() if and only if the parentCh is closed.
func ContextForChannel(parentCh <-chan struct{}) context.Context {
return channelContext{stopCh: parentCh}
}
var _ context.Context = channelContext{}
// channelContext will behave as if the context were cancelled when stopCh is
// closed.
type channelContext struct {
stopCh <-chan struct{}
}
func (c channelContext) Done() <-chan struct{} { return c.stopCh }
func (c channelContext) Err() error {
select {
case <-c.stopCh:
return context.Canceled
default:
return nil
}
}
func (c channelContext) Deadline() (time.Time, bool) { return time.Time{}, false }
func (c channelContext) Value(key any) any { return nil }
// runConditionWithCrashProtection runs a ConditionFunc with crash protection
func runConditionWithCrashProtection(condition ConditionFunc) (bool, error) {
return runConditionWithCrashProtectionWithContext(context.TODO(), condition.WithContext())
@ -290,25 +317,6 @@ func (b *Backoff) Step() time.Duration {
return duration
}
// ContextForChannel derives a child context from a parent channel.
//
// The derived context's Done channel is closed when the returned cancel function
// is called or when the parent channel is closed, whichever happens first.
//
// Note the caller must *always* call the CancelFunc, otherwise resources may be leaked.
func ContextForChannel(parentCh <-chan struct{}) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-parentCh:
cancel()
case <-ctx.Done():
}
}()
return ctx, cancel
}
// BackoffManager manages backoff with a particular scheme based on its underlying implementation. It provides
// an interface to return a timer for backoff, and caller shall backoff until Timer.C() drains. If the second Backoff()
// is called before the timer from the first Backoff() call finishes, the first timer will NOT be drained and result in
@ -466,9 +474,7 @@ func PollWithContext(ctx context.Context, interval, timeout time.Duration, condi
// PollUntil always waits interval before the first run of 'condition'.
// 'condition' will always be invoked at least once.
func PollUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error {
ctx, cancel := ContextForChannel(stopCh)
defer cancel()
return PollUntilWithContext(ctx, interval, condition.WithContext())
return PollUntilWithContext(ContextForChannel(stopCh), interval, condition.WithContext())
}
// PollUntilWithContext tries a condition func until it returns true,
@ -533,9 +539,7 @@ func PollImmediateWithContext(ctx context.Context, interval, timeout time.Durati
// PollImmediateUntil runs the 'condition' before waiting for the interval.
// 'condition' will always be invoked at least once.
func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error {
ctx, cancel := ContextForChannel(stopCh)
defer cancel()
return PollImmediateUntilWithContext(ctx, interval, condition.WithContext())
return PollImmediateUntilWithContext(ContextForChannel(stopCh), interval, condition.WithContext())
}
// PollImmediateUntilWithContext tries a condition func until it returns true,

View File

@ -523,8 +523,7 @@ func Test_waitFor(t *testing.T) {
err := func() error {
done := make(chan struct{})
defer close(done)
ctx, cancel := ContextForChannel(done)
defer cancel()
ctx := ContextForChannel(done)
return waitForWithContext(ctx, ticker.WithContext(), c.F.WithContext())
}()
switch {
@ -547,8 +546,7 @@ func Test_waitForWithEarlyClosing_waitFunc(t *testing.T) {
stopCh := make(chan struct{})
defer close(stopCh)
ctx, cancel := ContextForChannel(stopCh)
defer cancel()
ctx := ContextForChannel(stopCh)
start := time.Now()
err := waitForWithContext(ctx, func(ctx context.Context) <-chan struct{} {
c := make(chan struct{})
@ -575,8 +573,7 @@ func Test_waitForWithClosedChannel(t *testing.T) {
close(stopCh)
c := make(chan struct{})
defer close(c)
ctx, cancel := ContextForChannel(stopCh)
defer cancel()
ctx := ContextForChannel(stopCh)
start := time.Now()
err := waitForWithContext(ctx, func(_ context.Context) <-chan struct{} {
@ -699,8 +696,7 @@ func TestContextForChannel(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
ctx, cancel := ContextForChannel(parentCh)
defer cancel()
ctx := ContextForChannel(parentCh)
<-ctx.Done()
}()
}

View File

@ -17,6 +17,7 @@ limitations under the License.
package options
import (
"context"
"fmt"
"net/http"
"strconv"
@ -228,8 +229,8 @@ func (s *EtcdOptions) Complete(
}
if len(s.EncryptionProviderConfigFilepath) != 0 {
ctxTransformers, closeTransformers := wait.ContextForChannel(stopCh)
ctxServer, _ := wait.ContextForChannel(stopCh) // explicitly ignore cancel here because we do not own the server's lifecycle
ctxServer := wait.ContextForChannel(stopCh)
ctxTransformers, closeTransformers := context.WithCancel(ctxServer)
encryptionConfiguration, err := encryptionconfig.LoadEncryptionConfig(ctxTransformers, s.EncryptionProviderConfigFilepath, s.EncryptionProviderConfigAutomaticReload)
if err != nil {

View File

@ -199,7 +199,7 @@ func Run(c *cloudcontrollerconfig.CompletedConfig, cloud cloudprovider.Interface
}
if !c.ComponentConfig.Generic.LeaderElection.LeaderElect {
ctx, _ := wait.ContextForChannel(stopCh)
ctx := wait.ContextForChannel(stopCh)
run(ctx, controllerInitializers)
<-stopCh
return nil