diff --git a/controller/generic_controller.go b/controller/generic_controller.go index 6444e560..88f2410d 100644 --- a/controller/generic_controller.go +++ b/controller/generic_controller.go @@ -123,6 +123,20 @@ func (g *genericController) Enqueue(namespace, name string) { } func (g *genericController) AddHandler(ctx context.Context, name string, handler HandlerFunc) { + t := getHandlerTransaction(ctx) + if t == nil { + g.addHandler(ctx, name, handler) + return + } + + go func() { + if t.shouldContinue() { + g.addHandler(ctx, name, handler) + } + }() +} + +func (g *genericController) addHandler(ctx context.Context, name string, handler HandlerFunc) { g.Lock() defer g.Unlock() diff --git a/controller/transaction.go b/controller/transaction.go new file mode 100644 index 00000000..aa206a71 --- /dev/null +++ b/controller/transaction.go @@ -0,0 +1,47 @@ +package controller + +import ( + "context" +) + +type hTransactionKey struct{} + +type HandlerTransaction struct { + context.Context + parent context.Context + done chan struct{} + result bool +} + +func (h *HandlerTransaction) shouldContinue() bool { + select { + case <-h.parent.Done(): + return false + case <-h.done: + return h.result + } +} + +func (h *HandlerTransaction) Commit() { + h.result = true + close(h.done) +} + +func (h *HandlerTransaction) Rollback() { + close(h.done) +} + +func NewHandlerTransaction(ctx context.Context) *HandlerTransaction { + ht := &HandlerTransaction{ + parent: ctx, + done: make(chan struct{}), + } + ctx = context.WithValue(ctx, hTransactionKey{}, ht) + ht.Context = ctx + return ht +} + +func getHandlerTransaction(ctx context.Context) *HandlerTransaction { + v, _ := ctx.Value(hTransactionKey{}).(*HandlerTransaction) + return v +}