Simplify loopback NATS client.

Only use one goroutine per client instead of one per subscription.
This ensures that (like with the "real" client), all messages are
processed in order across different subscriptions.
This commit is contained in:
Joachim Bauch 2021-06-07 16:04:07 +02:00
parent fe95626f3b
commit c8886d03c9
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
5 changed files with 80 additions and 79 deletions

View file

@ -106,6 +106,7 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil
WaitForHub(ctx, t, hub) WaitForHub(ctx, t, hub)
(nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t) (nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t)
nats.Close()
server.Close() server.Close()
} }

View file

@ -120,6 +120,7 @@ func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Serve
WaitForHub(ctx, t, h) WaitForHub(ctx, t, h)
(nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t) (nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t)
nats.Close()
server.Close() server.Close()
} }

View file

@ -22,10 +22,11 @@
package signaling package signaling
import ( import (
"container/list"
"encoding/json" "encoding/json"
"log"
"strings" "strings"
"sync" "sync"
"time"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
) )
@ -33,90 +34,87 @@ import (
type LoopbackNatsClient struct { type LoopbackNatsClient struct {
mu sync.Mutex mu sync.Mutex
subscriptions map[string]map[*loopbackNatsSubscription]bool subscriptions map[string]map[*loopbackNatsSubscription]bool
stopping bool
wakeup sync.Cond
incoming list.List
} }
func NewLoopbackNatsClient() (NatsClient, error) { func NewLoopbackNatsClient() (NatsClient, error) {
return &LoopbackNatsClient{ client := &LoopbackNatsClient{
subscriptions: make(map[string]map[*loopbackNatsSubscription]bool), subscriptions: make(map[string]map[*loopbackNatsSubscription]bool),
}, nil }
client.wakeup.L = &client.mu
go client.processMessages()
return client, nil
}
func (c *LoopbackNatsClient) processMessages() {
c.mu.Lock()
defer c.mu.Unlock()
for {
for !c.stopping && c.incoming.Len() == 0 {
c.wakeup.Wait()
}
if c.stopping {
break
}
msg := c.incoming.Remove(c.incoming.Front()).(*nats.Msg)
c.processMessage(msg)
}
}
func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) {
subs, found := c.subscriptions[msg.Subject]
if !found {
return
}
channels := make([]chan *nats.Msg, 0, len(subs))
for sub := range subs {
channels = append(channels, sub.ch)
}
c.mu.Unlock()
defer c.mu.Lock()
for _, ch := range channels {
select {
case ch <- msg:
default:
log.Printf("Slow consumer %s, dropping message", msg.Subject)
}
}
} }
func (c *LoopbackNatsClient) Close() { func (c *LoopbackNatsClient) Close() {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
for _, subs := range c.subscriptions {
for sub := range subs {
sub.Unsubscribe() // nolint
}
}
c.subscriptions = nil c.subscriptions = nil
c.stopping = true
c.incoming.Init()
c.wakeup.Signal()
} }
type loopbackNatsSubscription struct { type loopbackNatsSubscription struct {
subject string subject string
client *LoopbackNatsClient client *LoopbackNatsClient
ch chan *nats.Msg ch chan *nats.Msg
incoming []*nats.Msg
cond sync.Cond
quit bool
} }
func (s *loopbackNatsSubscription) Unsubscribe() error { func (s *loopbackNatsSubscription) Unsubscribe() error {
s.cond.L.Lock()
if !s.quit {
s.quit = true
s.cond.Signal()
}
s.cond.L.Unlock()
s.client.unsubscribe(s) s.client.unsubscribe(s)
return nil return nil
} }
func (s *loopbackNatsSubscription) queue(msg *nats.Msg) {
s.cond.L.Lock()
s.incoming = append(s.incoming, msg)
if len(s.incoming) == 1 {
s.cond.Signal()
}
s.cond.L.Unlock()
}
func (s *loopbackNatsSubscription) run() {
s.cond.L.Lock()
defer s.cond.L.Unlock()
for !s.quit {
for !s.quit && len(s.incoming) == 0 {
s.cond.Wait()
}
for !s.quit && len(s.incoming) > 0 {
msg := s.incoming[0]
s.incoming = s.incoming[1:]
s.cond.L.Unlock()
// A "real" NATS server would take some time to process the request,
// simulate this by sleeping a tiny bit.
time.Sleep(time.Millisecond)
s.ch <- msg
s.cond.L.Lock()
}
}
}
func (c *LoopbackNatsClient) Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) { func (c *LoopbackNatsClient) Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) {
c.mu.Lock()
defer c.mu.Unlock()
return c.subscribe(subject, ch)
}
func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) {
if strings.HasSuffix(subject, ".") || strings.Contains(subject, " ") { if strings.HasSuffix(subject, ".") || strings.Contains(subject, " ") {
return nil, nats.ErrBadSubject return nil, nats.ErrBadSubject
} }
c.mu.Lock()
defer c.mu.Unlock()
if c.subscriptions == nil { if c.subscriptions == nil {
return nil, nats.ErrConnectionClosed return nil, nats.ErrConnectionClosed
} }
@ -126,7 +124,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS
client: c, client: c,
ch: ch, ch: ch,
} }
s.cond.L = &sync.Mutex{}
subs, found := c.subscriptions[subject] subs, found := c.subscriptions[subject]
if !found { if !found {
subs = make(map[*loopbackNatsSubscription]bool) subs = make(map[*loopbackNatsSubscription]bool)
@ -134,7 +131,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS
} }
subs[s] = true subs[s] = true
go s.run()
return s, nil return s, nil
} }
@ -161,7 +157,6 @@ func (c *LoopbackNatsClient) Publish(subject string, message interface{}) error
return nats.ErrConnectionClosed return nats.ErrConnectionClosed
} }
if subs, found := c.subscriptions[subject]; found {
msg := &nats.Msg{ msg := &nats.Msg{
Subject: subject, Subject: subject,
} }
@ -169,10 +164,8 @@ func (c *LoopbackNatsClient) Publish(subject string, message interface{}) error
if msg.Data, err = json.Marshal(message); err != nil { if msg.Data, err = json.Marshal(message); err != nil {
return err return err
} }
for s := range subs { c.incoming.PushBack(msg)
s.queue(msg) c.wakeup.Signal()
}
}
return nil return nil
} }

View file

@ -48,17 +48,20 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t
} }
} }
func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient { func CreateLoopbackNatsClientForTest(t *testing.T) (NatsClient, func()) {
result, err := NewLoopbackNatsClient() result, err := NewLoopbackNatsClient()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return result return result, func() {
result.Close()
}
} }
func TestLoopbackNatsClient_Subscribe(t *testing.T) { func TestLoopbackNatsClient_Subscribe(t *testing.T) {
ensureNoGoroutinesLeak(t, func() { ensureNoGoroutinesLeak(t, func() {
client := CreateLoopbackNatsClientForTest(t) client, shutdown := CreateLoopbackNatsClientForTest(t)
defer shutdown()
testNatsClient_Subscribe(t, client) testNatsClient_Subscribe(t, client)
}) })
@ -66,7 +69,8 @@ func TestLoopbackNatsClient_Subscribe(t *testing.T) {
func TestLoopbackClient_PublishAfterClose(t *testing.T) { func TestLoopbackClient_PublishAfterClose(t *testing.T) {
ensureNoGoroutinesLeak(t, func() { ensureNoGoroutinesLeak(t, func() {
client := CreateLoopbackNatsClientForTest(t) client, shutdown := CreateLoopbackNatsClientForTest(t)
defer shutdown()
testNatsClient_PublishAfterClose(t, client) testNatsClient_PublishAfterClose(t, client)
}) })
@ -74,7 +78,8 @@ func TestLoopbackClient_PublishAfterClose(t *testing.T) {
func TestLoopbackClient_SubscribeAfterClose(t *testing.T) { func TestLoopbackClient_SubscribeAfterClose(t *testing.T) {
ensureNoGoroutinesLeak(t, func() { ensureNoGoroutinesLeak(t, func() {
client := CreateLoopbackNatsClientForTest(t) client, shutdown := CreateLoopbackNatsClientForTest(t)
defer shutdown()
testNatsClient_SubscribeAfterClose(t, client) testNatsClient_SubscribeAfterClose(t, client)
}) })
@ -82,7 +87,8 @@ func TestLoopbackClient_SubscribeAfterClose(t *testing.T) {
func TestLoopbackClient_BadSubjects(t *testing.T) { func TestLoopbackClient_BadSubjects(t *testing.T) {
ensureNoGoroutinesLeak(t, func() { ensureNoGoroutinesLeak(t, func() {
client := CreateLoopbackNatsClientForTest(t) client, shutdown := CreateLoopbackNatsClientForTest(t)
defer shutdown()
testNatsClient_BadSubjects(t, client) testNatsClient_BadSubjects(t, client)
}) })

View file

@ -90,7 +90,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
} }
// Allow NATS goroutines to process messages. // Allow NATS goroutines to process messages.
time.Sleep(time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
<-ch <-ch