From c8886d03c978cd1887e60619ae9da98556127156 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Mon, 7 Jun 2021 16:04:07 +0200 Subject: [PATCH] Simplify loopback NATS client. Only use one goroutine per client instead of one per subscription. This ensures that (like with the "real" client), all messages are processed in order across different subscriptions. --- backend_server_test.go | 1 + hub_test.go | 1 + natsclient_loopback.go | 137 +++++++++++++++++------------------- natsclient_loopback_test.go | 18 +++-- natsclient_test.go | 2 +- 5 files changed, 80 insertions(+), 79 deletions(-) diff --git a/backend_server_test.go b/backend_server_test.go index 5e2849b..633bea8 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -106,6 +106,7 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil WaitForHub(ctx, t, hub) (nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t) + nats.Close() server.Close() } diff --git a/hub_test.go b/hub_test.go index 0a12df4..7904c1e 100644 --- a/hub_test.go +++ b/hub_test.go @@ -120,6 +120,7 @@ func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Serve WaitForHub(ctx, t, h) (nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t) + nats.Close() server.Close() } diff --git a/natsclient_loopback.go b/natsclient_loopback.go index e9c33d7..7a7d819 100644 --- a/natsclient_loopback.go +++ b/natsclient_loopback.go @@ -22,10 +22,11 @@ package signaling import ( + "container/list" "encoding/json" + "log" "strings" "sync" - "time" "github.com/nats-io/nats.go" ) @@ -33,90 +34,87 @@ import ( type LoopbackNatsClient struct { mu sync.Mutex subscriptions map[string]map[*loopbackNatsSubscription]bool + + stopping bool + wakeup sync.Cond + incoming list.List } func NewLoopbackNatsClient() (NatsClient, error) { - return &LoopbackNatsClient{ + client := &LoopbackNatsClient{ subscriptions: make(map[string]map[*loopbackNatsSubscription]bool), - }, nil + } + client.wakeup.L = &client.mu + go client.processMessages() + return client, nil +} + +func (c *LoopbackNatsClient) processMessages() { + c.mu.Lock() + defer c.mu.Unlock() + for { + for !c.stopping && c.incoming.Len() == 0 { + c.wakeup.Wait() + } + if c.stopping { + break + } + + msg := c.incoming.Remove(c.incoming.Front()).(*nats.Msg) + c.processMessage(msg) + } +} + +func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) { + subs, found := c.subscriptions[msg.Subject] + if !found { + return + } + + channels := make([]chan *nats.Msg, 0, len(subs)) + for sub := range subs { + channels = append(channels, sub.ch) + } + c.mu.Unlock() + defer c.mu.Lock() + for _, ch := range channels { + select { + case ch <- msg: + default: + log.Printf("Slow consumer %s, dropping message", msg.Subject) + } + } } func (c *LoopbackNatsClient) Close() { c.mu.Lock() defer c.mu.Unlock() - for _, subs := range c.subscriptions { - for sub := range subs { - sub.Unsubscribe() // nolint - } - } - c.subscriptions = nil + c.stopping = true + c.incoming.Init() + c.wakeup.Signal() } type loopbackNatsSubscription struct { - subject string - client *LoopbackNatsClient - ch chan *nats.Msg - incoming []*nats.Msg - cond sync.Cond - quit bool + subject string + client *LoopbackNatsClient + + ch chan *nats.Msg } func (s *loopbackNatsSubscription) Unsubscribe() error { - s.cond.L.Lock() - if !s.quit { - s.quit = true - s.cond.Signal() - } - s.cond.L.Unlock() - s.client.unsubscribe(s) return nil } -func (s *loopbackNatsSubscription) queue(msg *nats.Msg) { - s.cond.L.Lock() - s.incoming = append(s.incoming, msg) - if len(s.incoming) == 1 { - s.cond.Signal() - } - s.cond.L.Unlock() -} - -func (s *loopbackNatsSubscription) run() { - s.cond.L.Lock() - defer s.cond.L.Unlock() - for !s.quit { - for !s.quit && len(s.incoming) == 0 { - s.cond.Wait() - } - - for !s.quit && len(s.incoming) > 0 { - msg := s.incoming[0] - s.incoming = s.incoming[1:] - s.cond.L.Unlock() - // A "real" NATS server would take some time to process the request, - // simulate this by sleeping a tiny bit. - time.Sleep(time.Millisecond) - s.ch <- msg - s.cond.L.Lock() - } - } -} - func (c *LoopbackNatsClient) Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) { - c.mu.Lock() - defer c.mu.Unlock() - - return c.subscribe(subject, ch) -} - -func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) { if strings.HasSuffix(subject, ".") || strings.Contains(subject, " ") { return nil, nats.ErrBadSubject } + c.mu.Lock() + defer c.mu.Unlock() if c.subscriptions == nil { return nil, nats.ErrConnectionClosed } @@ -126,7 +124,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS client: c, ch: ch, } - s.cond.L = &sync.Mutex{} subs, found := c.subscriptions[subject] if !found { subs = make(map[*loopbackNatsSubscription]bool) @@ -134,7 +131,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS } subs[s] = true - go s.run() return s, nil } @@ -161,18 +157,15 @@ func (c *LoopbackNatsClient) Publish(subject string, message interface{}) error return nats.ErrConnectionClosed } - if subs, found := c.subscriptions[subject]; found { - msg := &nats.Msg{ - Subject: subject, - } - var err error - if msg.Data, err = json.Marshal(message); err != nil { - return err - } - for s := range subs { - s.queue(msg) - } + msg := &nats.Msg{ + Subject: subject, } + var err error + if msg.Data, err = json.Marshal(message); err != nil { + return err + } + c.incoming.PushBack(msg) + c.wakeup.Signal() return nil } diff --git a/natsclient_loopback_test.go b/natsclient_loopback_test.go index 865498a..99aad5b 100644 --- a/natsclient_loopback_test.go +++ b/natsclient_loopback_test.go @@ -48,17 +48,20 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t } } -func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient { +func CreateLoopbackNatsClientForTest(t *testing.T) (NatsClient, func()) { result, err := NewLoopbackNatsClient() if err != nil { t.Fatal(err) } - return result + return result, func() { + result.Close() + } } func TestLoopbackNatsClient_Subscribe(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_Subscribe(t, client) }) @@ -66,7 +69,8 @@ func TestLoopbackNatsClient_Subscribe(t *testing.T) { func TestLoopbackClient_PublishAfterClose(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_PublishAfterClose(t, client) }) @@ -74,7 +78,8 @@ func TestLoopbackClient_PublishAfterClose(t *testing.T) { func TestLoopbackClient_SubscribeAfterClose(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_SubscribeAfterClose(t, client) }) @@ -82,7 +87,8 @@ func TestLoopbackClient_SubscribeAfterClose(t *testing.T) { func TestLoopbackClient_BadSubjects(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_BadSubjects(t, client) }) diff --git a/natsclient_test.go b/natsclient_test.go index 7afe06c..67f377b 100644 --- a/natsclient_test.go +++ b/natsclient_test.go @@ -90,7 +90,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { } // Allow NATS goroutines to process messages. - time.Sleep(time.Millisecond) + time.Sleep(10 * time.Millisecond) } <-ch