diff --git a/notifier.go b/notifier.go index 0c46d6a..ffe7da3 100644 --- a/notifier.go +++ b/notifier.go @@ -28,12 +28,14 @@ import ( type Waiter struct { key string - ch chan bool + + ctx context.Context + cancel context.CancelFunc } func (w *Waiter) Wait(ctx context.Context) error { select { - case <-w.ch: + case <-w.ctx.Done(): return nil case <-ctx.Done(): return ctx.Err() @@ -43,26 +45,42 @@ func (w *Waiter) Wait(ctx context.Context) error { type Notifier struct { sync.Mutex - waiters map[string]*Waiter + waiters map[string]*Waiter + waiterMap map[string]map[*Waiter]bool } func (n *Notifier) NewWaiter(key string) *Waiter { n.Lock() defer n.Unlock() - _, found := n.waiters[key] + waiter, found := n.waiters[key] if found { - panic("already waiting") + w := &Waiter{ + key: key, + ctx: waiter.ctx, + cancel: waiter.cancel, + } + n.waiterMap[key][w] = true + return w } - waiter := &Waiter{ - key: key, - ch: make(chan bool, 1), + ctx, cancel := context.WithCancel(context.Background()) + waiter = &Waiter{ + key: key, + ctx: ctx, + cancel: cancel, } if n.waiters == nil { n.waiters = make(map[string]*Waiter) } + if n.waiterMap == nil { + n.waiterMap = make(map[string]map[*Waiter]bool) + } n.waiters[key] = waiter + if _, found := n.waiterMap[key]; !found { + n.waiterMap[key] = make(map[*Waiter]bool) + } + n.waiterMap[key][waiter] = true return waiter } @@ -71,18 +89,24 @@ func (n *Notifier) Reset() { defer n.Unlock() for _, w := range n.waiters { - close(w.ch) + w.cancel() } n.waiters = nil + n.waiterMap = nil } func (n *Notifier) Release(w *Waiter) { n.Lock() defer n.Unlock() - if _, found := n.waiters[w.key]; found { - delete(n.waiters, w.key) - close(w.ch) + if waiters, found := n.waiterMap[w.key]; found { + if _, found := waiters[w]; found { + delete(waiters, w) + if len(waiters) == 0 { + delete(n.waiters, w.key) + w.cancel() + } + } } } @@ -91,10 +115,8 @@ func (n *Notifier) Notify(key string) { defer n.Unlock() if w, found := n.waiters[key]; found { - select { - case w.ch <- true: - default: - // Ignore, already notified - } + w.cancel() + delete(n.waiters, w.key) + delete(n.waiterMap, w.key) } } diff --git a/notifier_test.go b/notifier_test.go index bade2d7..6c983dd 100644 --- a/notifier_test.go +++ b/notifier_test.go @@ -79,6 +79,22 @@ func TestNotifierWaitClosed(t *testing.T) { } } +func TestNotifierWaitClosedMulti(t *testing.T) { + var notifier Notifier + + waiter1 := notifier.NewWaiter("foo") + waiter2 := notifier.NewWaiter("foo") + notifier.Release(waiter1) + notifier.Release(waiter2) + + if err := waiter1.Wait(context.Background()); err != nil { + t.Error(err) + } + if err := waiter2.Wait(context.Background()); err != nil { + t.Error(err) + } +} + func TestNotifierResetWillNotify(t *testing.T) { var notifier Notifier @@ -103,18 +119,32 @@ func TestNotifierResetWillNotify(t *testing.T) { func TestNotifierDuplicate(t *testing.T) { var notifier Notifier + var wgStart sync.WaitGroup + var wgEnd sync.WaitGroup - waiter := notifier.NewWaiter("foo") - defer notifier.Release(waiter) + for i := 0; i < 2; i++ { + wgStart.Add(1) + wgEnd.Add(1) - defer func() { - if e := recover(); e != nil { - if e.(string) != "already waiting" { - t.Errorf("Expected error about already waiting, got %+v", e) + go func() { + defer wgEnd.Done() + waiter := notifier.NewWaiter("foo") + defer notifier.Release(waiter) + + // Goroutine has created the waiter and is ready. + wgStart.Done() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := waiter.Wait(ctx); err != nil { + t.Error(err) } - } - }() + }() + } - // Creating a waiter for an existing key will panic. - notifier.NewWaiter("foo") + wgStart.Wait() + + time.Sleep(100 * time.Millisecond) + notifier.Notify("foo") + wgEnd.Wait() }