diff --git a/backend_server_test.go b/backend_server_test.go index 5e2849b..633bea8 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -106,6 +106,7 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil WaitForHub(ctx, t, hub) (nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t) + nats.Close() server.Close() } diff --git a/hub_test.go b/hub_test.go index 0a12df4..7904c1e 100644 --- a/hub_test.go +++ b/hub_test.go @@ -120,6 +120,7 @@ func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Serve WaitForHub(ctx, t, h) (nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t) + nats.Close() server.Close() } diff --git a/natsclient_loopback.go b/natsclient_loopback.go index e9c33d7..7a7d819 100644 --- a/natsclient_loopback.go +++ b/natsclient_loopback.go @@ -22,10 +22,11 @@ package signaling import ( + "container/list" "encoding/json" + "log" "strings" "sync" - "time" "github.com/nats-io/nats.go" ) @@ -33,90 +34,87 @@ import ( type LoopbackNatsClient struct { mu sync.Mutex subscriptions map[string]map[*loopbackNatsSubscription]bool + + stopping bool + wakeup sync.Cond + incoming list.List } func NewLoopbackNatsClient() (NatsClient, error) { - return &LoopbackNatsClient{ + client := &LoopbackNatsClient{ subscriptions: make(map[string]map[*loopbackNatsSubscription]bool), - }, nil + } + client.wakeup.L = &client.mu + go client.processMessages() + return client, nil +} + +func (c *LoopbackNatsClient) processMessages() { + c.mu.Lock() + defer c.mu.Unlock() + for { + for !c.stopping && c.incoming.Len() == 0 { + c.wakeup.Wait() + } + if c.stopping { + break + } + + msg := c.incoming.Remove(c.incoming.Front()).(*nats.Msg) + c.processMessage(msg) + } +} + +func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) { + subs, found := c.subscriptions[msg.Subject] + if !found { + return + } + + channels := make([]chan *nats.Msg, 0, len(subs)) + for sub := range subs { + channels = append(channels, sub.ch) + } + c.mu.Unlock() + defer c.mu.Lock() + for _, ch := range channels { + select { + case ch <- msg: + default: + log.Printf("Slow consumer %s, dropping message", msg.Subject) + } + } } func (c *LoopbackNatsClient) Close() { c.mu.Lock() defer c.mu.Unlock() - for _, subs := range c.subscriptions { - for sub := range subs { - sub.Unsubscribe() // nolint - } - } - c.subscriptions = nil + c.stopping = true + c.incoming.Init() + c.wakeup.Signal() } type loopbackNatsSubscription struct { - subject string - client *LoopbackNatsClient - ch chan *nats.Msg - incoming []*nats.Msg - cond sync.Cond - quit bool + subject string + client *LoopbackNatsClient + + ch chan *nats.Msg } func (s *loopbackNatsSubscription) Unsubscribe() error { - s.cond.L.Lock() - if !s.quit { - s.quit = true - s.cond.Signal() - } - s.cond.L.Unlock() - s.client.unsubscribe(s) return nil } -func (s *loopbackNatsSubscription) queue(msg *nats.Msg) { - s.cond.L.Lock() - s.incoming = append(s.incoming, msg) - if len(s.incoming) == 1 { - s.cond.Signal() - } - s.cond.L.Unlock() -} - -func (s *loopbackNatsSubscription) run() { - s.cond.L.Lock() - defer s.cond.L.Unlock() - for !s.quit { - for !s.quit && len(s.incoming) == 0 { - s.cond.Wait() - } - - for !s.quit && len(s.incoming) > 0 { - msg := s.incoming[0] - s.incoming = s.incoming[1:] - s.cond.L.Unlock() - // A "real" NATS server would take some time to process the request, - // simulate this by sleeping a tiny bit. - time.Sleep(time.Millisecond) - s.ch <- msg - s.cond.L.Lock() - } - } -} - func (c *LoopbackNatsClient) Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) { - c.mu.Lock() - defer c.mu.Unlock() - - return c.subscribe(subject, ch) -} - -func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) { if strings.HasSuffix(subject, ".") || strings.Contains(subject, " ") { return nil, nats.ErrBadSubject } + c.mu.Lock() + defer c.mu.Unlock() if c.subscriptions == nil { return nil, nats.ErrConnectionClosed } @@ -126,7 +124,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS client: c, ch: ch, } - s.cond.L = &sync.Mutex{} subs, found := c.subscriptions[subject] if !found { subs = make(map[*loopbackNatsSubscription]bool) @@ -134,7 +131,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS } subs[s] = true - go s.run() return s, nil } @@ -161,18 +157,15 @@ func (c *LoopbackNatsClient) Publish(subject string, message interface{}) error return nats.ErrConnectionClosed } - if subs, found := c.subscriptions[subject]; found { - msg := &nats.Msg{ - Subject: subject, - } - var err error - if msg.Data, err = json.Marshal(message); err != nil { - return err - } - for s := range subs { - s.queue(msg) - } + msg := &nats.Msg{ + Subject: subject, } + var err error + if msg.Data, err = json.Marshal(message); err != nil { + return err + } + c.incoming.PushBack(msg) + c.wakeup.Signal() return nil } diff --git a/natsclient_loopback_test.go b/natsclient_loopback_test.go index 865498a..99aad5b 100644 --- a/natsclient_loopback_test.go +++ b/natsclient_loopback_test.go @@ -48,17 +48,20 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t } } -func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient { +func CreateLoopbackNatsClientForTest(t *testing.T) (NatsClient, func()) { result, err := NewLoopbackNatsClient() if err != nil { t.Fatal(err) } - return result + return result, func() { + result.Close() + } } func TestLoopbackNatsClient_Subscribe(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_Subscribe(t, client) }) @@ -66,7 +69,8 @@ func TestLoopbackNatsClient_Subscribe(t *testing.T) { func TestLoopbackClient_PublishAfterClose(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_PublishAfterClose(t, client) }) @@ -74,7 +78,8 @@ func TestLoopbackClient_PublishAfterClose(t *testing.T) { func TestLoopbackClient_SubscribeAfterClose(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_SubscribeAfterClose(t, client) }) @@ -82,7 +87,8 @@ func TestLoopbackClient_SubscribeAfterClose(t *testing.T) { func TestLoopbackClient_BadSubjects(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_BadSubjects(t, client) }) diff --git a/natsclient_test.go b/natsclient_test.go index 7afe06c..67f377b 100644 --- a/natsclient_test.go +++ b/natsclient_test.go @@ -90,7 +90,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { } // Allow NATS goroutines to process messages. - time.Sleep(time.Millisecond) + time.Sleep(10 * time.Millisecond) } <-ch