diff --git a/notifier.go b/notifier.go index 94af5bd..3466f45 100644 --- a/notifier.go +++ b/notifier.go @@ -29,7 +29,11 @@ import ( type Waiter struct { key string - SingleWaiter + sw *SingleWaiter +} + +func (w *Waiter) Wait(ctx context.Context) error { + return w.sw.Wait(ctx) } type Notifier struct { @@ -47,22 +51,15 @@ func (n *Notifier) NewWaiter(key string) *Waiter { if found { w := &Waiter{ key: key, - SingleWaiter: SingleWaiter{ - ctx: waiter.ctx, - cancel: waiter.cancel, - }, + sw: waiter.sw, } n.waiterMap[key][w] = true return w } - ctx, cancel := context.WithCancel(context.Background()) waiter = &Waiter{ key: key, - SingleWaiter: SingleWaiter{ - ctx: ctx, - cancel: cancel, - }, + sw: newSingleWaiter(), } if n.waiters == nil { n.waiters = make(map[string]*Waiter) @@ -83,7 +80,7 @@ func (n *Notifier) Reset() { defer n.Unlock() for _, w := range n.waiters { - w.cancel() + w.sw.cancel() } n.waiters = nil n.waiterMap = nil @@ -98,7 +95,7 @@ func (n *Notifier) Release(w *Waiter) { delete(waiters, w) if len(waiters) == 0 { delete(n.waiters, w.key) - w.cancel() + w.sw.cancel() } } } @@ -109,7 +106,7 @@ func (n *Notifier) Notify(key string) { defer n.Unlock() if w, found := n.waiters[key]; found { - w.cancel() + w.sw.cancel() delete(n.waiters, w.key) delete(n.waiterMap, w.key) } diff --git a/single_notifier.go b/single_notifier.go index 91c4b6f..921542a 100644 --- a/single_notifier.go +++ b/single_notifier.go @@ -27,19 +27,43 @@ import ( ) type SingleWaiter struct { - ctx context.Context - cancel context.CancelFunc + root bool + ch chan struct{} + once sync.Once +} + +func newSingleWaiter() *SingleWaiter { + return &SingleWaiter{ + root: true, + ch: make(chan struct{}), + } +} + +func (w *SingleWaiter) subWaiter() *SingleWaiter { + return &SingleWaiter{ + ch: w.ch, + } } func (w *SingleWaiter) Wait(ctx context.Context) error { select { - case <-w.ctx.Done(): + case <-w.ch: return nil case <-ctx.Done(): return ctx.Err() } } +func (w *SingleWaiter) cancel() { + if !w.root { + return + } + + w.once.Do(func() { + close(w.ch) + }) +} + type SingleNotifier struct { sync.Mutex @@ -52,21 +76,14 @@ func (n *SingleNotifier) NewWaiter() *SingleWaiter { defer n.Unlock() if n.waiter == nil { - ctx, cancel := context.WithCancel(context.Background()) - n.waiter = &SingleWaiter{ - ctx: ctx, - cancel: cancel, - } + n.waiter = newSingleWaiter() } if n.waiters == nil { n.waiters = make(map[*SingleWaiter]bool) } - w := &SingleWaiter{ - ctx: n.waiter.ctx, - cancel: n.waiter.cancel, - } + w := n.waiter.subWaiter() n.waiters[w] = true return w }