diff --git a/api_signaling.go b/api_signaling.go index d30ea6e..78754c8 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -95,6 +95,14 @@ func (m *ClientMessage) CheckValid() error { return nil } +func (m *ClientMessage) String() string { + data, err := json.Marshal(m) + if err != nil { + return fmt.Sprintf("Could not serialize %#v: %s", m, err) + } + return string(data) +} + func (m *ClientMessage) NewErrorServerMessage(e *Error) *ServerMessage { return &ServerMessage{ Id: m.Id, @@ -179,6 +187,14 @@ func (r *ServerMessage) IsParticipantsUpdate() bool { return true } +func (r *ServerMessage) String() string { + data, err := json.Marshal(r) + if err != nil { + return fmt.Sprintf("Could not serialize %#v: %s", r, err) + } + return string(data) +} + type Error struct { Code string `json:"code"` Message string `json:"message"` diff --git a/backend_configuration_test.go b/backend_configuration_test.go index 107c787..252d6e8 100644 --- a/backend_configuration_test.go +++ b/backend_configuration_test.go @@ -32,47 +32,59 @@ import ( func testUrls(t *testing.T, config *BackendConfiguration, valid_urls []string, invalid_urls []string) { for _, u := range valid_urls { - parsed, err := url.ParseRequestURI(u) - if err != nil { - t.Errorf("The url %s should be valid, got %s", u, err) - continue - } - if !config.IsUrlAllowed(parsed) { - t.Errorf("The url %s should be allowed", u) - } - if secret := config.GetSecret(parsed); !bytes.Equal(secret, testBackendSecret) { - t.Errorf("Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret)) - } + u := u + t.Run(u, func(t *testing.T) { + parsed, err := url.ParseRequestURI(u) + if err != nil { + t.Errorf("The url %s should be valid, got %s", u, err) + return + } + if !config.IsUrlAllowed(parsed) { + t.Errorf("The url %s should be allowed", u) + } + if secret := config.GetSecret(parsed); !bytes.Equal(secret, testBackendSecret) { + t.Errorf("Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret)) + } + }) } for _, u := range invalid_urls { - parsed, _ := url.ParseRequestURI(u) - if config.IsUrlAllowed(parsed) { - t.Errorf("The url %s should not be allowed", u) - } + u := u + t.Run(u, func(t *testing.T) { + parsed, _ := url.ParseRequestURI(u) + if config.IsUrlAllowed(parsed) { + t.Errorf("The url %s should not be allowed", u) + } + }) } } func testBackends(t *testing.T, config *BackendConfiguration, valid_urls [][]string, invalid_urls []string) { for _, entry := range valid_urls { - u := entry[0] - parsed, err := url.ParseRequestURI(u) - if err != nil { - t.Errorf("The url %s should be valid, got %s", u, err) - continue - } - if !config.IsUrlAllowed(parsed) { - t.Errorf("The url %s should be allowed", u) - } - s := entry[1] - if secret := config.GetSecret(parsed); !bytes.Equal(secret, []byte(s)) { - t.Errorf("Expected secret %s for url %s, got %s", string(s), u, string(secret)) - } + entry := entry + t.Run(entry[0], func(t *testing.T) { + u := entry[0] + parsed, err := url.ParseRequestURI(u) + if err != nil { + t.Errorf("The url %s should be valid, got %s", u, err) + return + } + if !config.IsUrlAllowed(parsed) { + t.Errorf("The url %s should be allowed", u) + } + s := entry[1] + if secret := config.GetSecret(parsed); !bytes.Equal(secret, []byte(s)) { + t.Errorf("Expected secret %s for url %s, got %s", string(s), u, string(secret)) + } + }) } for _, u := range invalid_urls { - parsed, _ := url.ParseRequestURI(u) - if config.IsUrlAllowed(parsed) { - t.Errorf("The url %s should not be allowed", u) - } + u := u + t.Run(u, func(t *testing.T) { + parsed, _ := url.ParseRequestURI(u) + if config.IsUrlAllowed(parsed) { + t.Errorf("The url %s should not be allowed", u) + } + }) } } diff --git a/backend_server_test.go b/backend_server_test.go index 5e2849b..633bea8 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -106,6 +106,7 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil WaitForHub(ctx, t, hub) (nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t) + nats.Close() server.Close() } diff --git a/client.go b/client.go index 8358d59..943b465 100644 --- a/client.go +++ b/client.go @@ -100,9 +100,10 @@ type Client struct { mu sync.Mutex - closeChan chan bool - messagesDone sync.WaitGroup - messageChan chan *bytes.Buffer + closeChan chan bool + messagesDone sync.WaitGroup + messageChan chan *bytes.Buffer + messageProcessing uint32 OnLookupCountry func(*Client) string OnClosed func(*Client) @@ -183,9 +184,24 @@ func (c *Client) Close() { return } + c.mu.Lock() + if c.conn != nil { + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) // nolint + } + c.mu.Unlock() + + if atomic.LoadUint32(&c.messageProcessing) == 1 { + // Defer closing + atomic.StoreUint32(&c.closed, 2) + return + } + + c.doClose() +} + +func (c *Client) doClose() { c.closeChan <- true c.messagesDone.Wait() - close(c.messageChan) c.OnClosed(c) c.SetSession(nil) @@ -231,6 +247,7 @@ func (c *Client) SendMessage(message WritableClientMessage) bool { func (c *Client) ReadPump() { defer func() { c.Close() + close(c.messageChan) }() addr := c.RemoteAddr() @@ -304,7 +321,7 @@ func (c *Client) ReadPump() { } // Stop processing if the client was closed. - if atomic.LoadUint32(&c.closed) == 1 { + if atomic.LoadUint32(&c.closed) != 0 { bufferPool.Put(decodeBuffer) break } @@ -321,10 +338,16 @@ func (c *Client) processMessages() { break } + atomic.StoreUint32(&c.messageProcessing, 1) c.OnMessageReceived(c, buffer.Bytes()) + atomic.StoreUint32(&c.messageProcessing, 0) c.messagesDone.Done() bufferPool.Put(buffer) } + + if atomic.LoadUint32(&c.closed) == 2 { + c.doClose() + } } func (c *Client) writeInternal(message json.Marshaler) bool { diff --git a/clientsession.go b/clientsession.go index 8d0da87..11de3fc 100644 --- a/clientsession.go +++ b/clientsession.go @@ -46,6 +46,8 @@ var ( ) type ClientSession struct { + roomJoinTime int64 + running int32 hub *Hub privateId string @@ -289,12 +291,21 @@ func (s *ClientSession) IsExpired(now time.Time) bool { func (s *ClientSession) SetRoom(room *Room) { atomic.StorePointer(&s.room, unsafe.Pointer(room)) + if room != nil { + atomic.StoreInt64(&s.roomJoinTime, time.Now().UnixNano()) + } else { + atomic.StoreInt64(&s.roomJoinTime, 0) + } } func (s *ClientSession) GetRoom() *Room { return (*Room)(atomic.LoadPointer(&s.room)) } +func (s *ClientSession) getRoomJoinTime() time.Time { + return time.Unix(0, atomic.LoadInt64(&s.roomJoinTime)) +} + func (s *ClientSession) releaseMcuObjects() { if len(s.publishers) > 0 { go func(publishers map[string]McuPublisher) { @@ -815,6 +826,13 @@ func (s *ClientSession) processNatsMessage(msg *NatsMessage) *ServerMessage { // TODO(jojo): Only send all users if current session id has // changed its "inCall" flag to true. m.Changed = nil + } else if msg.Message.Event.Target == "room" { + // Can happen mostly during tests where an older room NATS message + // could be received by a subscriber that joined after it was sent. + if msg.SendTime.Before(s.getRoomJoinTime()) { + log.Printf("Message %+v was sent before room was joined, ignoring", msg.Message) + return nil + } } } diff --git a/clientsession_test.go b/clientsession_test.go index 028109c..1dadd3d 100644 --- a/clientsession_test.go +++ b/clientsession_test.go @@ -22,6 +22,7 @@ package signaling import ( + "strconv" "testing" ) @@ -111,10 +112,13 @@ func Test_permissionsEqual(t *testing.T) { equal: false, }, } - for _, test := range tests { - equal := permissionsEqual(test.a, test.b) - if equal != test.equal { - t.Errorf("Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal]) - } + for idx, test := range tests { + test := test + t.Run(strconv.Itoa(idx), func(t *testing.T) { + equal := permissionsEqual(test.a, test.b) + if equal != test.equal { + t.Errorf("Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal]) + } + }) } } diff --git a/geoip_test.go b/geoip_test.go index cfc4541..8413f29 100644 --- a/geoip_test.go +++ b/geoip_test.go @@ -42,15 +42,19 @@ func testGeoLookupReader(t *testing.T, reader *GeoLookup) { } for ip, expected := range tests { - country, err := reader.LookupCountry(net.ParseIP(ip)) - if err != nil { - t.Errorf("Could not lookup %s: %s", ip, err) - continue - } + ip := ip + expected := expected + t.Run(ip, func(t *testing.T) { + country, err := reader.LookupCountry(net.ParseIP(ip)) + if err != nil { + t.Errorf("Could not lookup %s: %s", ip, err) + return + } - if country != expected { - t.Errorf("Expected %s for %s, got %s", expected, ip, country) - } + if country != expected { + t.Errorf("Expected %s for %s, got %s", expected, ip, country) + } + }) } } @@ -106,17 +110,21 @@ func TestGeoLookupContinent(t *testing.T) { } for country, expected := range tests { - continents := LookupContinents(country) - if len(continents) != len(expected) { - t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected) - continue - } - for idx, c := range expected { - if continents[idx] != c { + country := country + expected := expected + t.Run(country, func(t *testing.T) { + continents := LookupContinents(country) + if len(continents) != len(expected) { t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected) - break + return } - } + for idx, c := range expected { + if continents[idx] != c { + t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected) + break + } + } + }) } } diff --git a/hub_test.go b/hub_test.go index 0a12df4..7904c1e 100644 --- a/hub_test.go +++ b/hub_test.go @@ -120,6 +120,7 @@ func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Serve WaitForHub(ctx, t, h) (nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t) + nats.Close() server.Close() } diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index 29f24cc..ef2ecce 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go @@ -76,11 +76,15 @@ func Test_sortConnectionsForCountry(t *testing.T) { } for country, test := range testcases { - sorted := sortConnectionsForCountry(test[0], country) - for idx, conn := range sorted { - if test[1][idx] != conn { - t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country()) + country := country + test := test + t.Run(country, func(t *testing.T) { + sorted := sortConnectionsForCountry(test[0], country) + for idx, conn := range sorted { + if test[1][idx] != conn { + t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country()) + } } - } + }) } } diff --git a/natsclient.go b/natsclient.go index 6925b57..a9e07ae 100644 --- a/natsclient.go +++ b/natsclient.go @@ -38,6 +38,8 @@ const ( ) type NatsMessage struct { + SendTime time.Time `json:"sendtime"` + Type string `json:"type"` Message *ServerMessage `json:"message,omitempty"` @@ -150,16 +152,18 @@ func (c *natsClient) PublishNats(subject string, message *NatsMessage) error { func (c *natsClient) PublishMessage(subject string, message *ServerMessage) error { msg := &NatsMessage{ - Type: "message", - Message: message, + SendTime: time.Now(), + Type: "message", + Message: message, } return c.PublishNats(subject, msg) } func (c *natsClient) PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error { msg := &NatsMessage{ - Type: "room", - Room: message, + SendTime: time.Now(), + Type: "room", + Room: message, } return c.PublishNats(subject, msg) } diff --git a/natsclient_loopback.go b/natsclient_loopback.go index e9c33d7..aaa1699 100644 --- a/natsclient_loopback.go +++ b/natsclient_loopback.go @@ -22,7 +22,9 @@ package signaling import ( + "container/list" "encoding/json" + "log" "strings" "sync" "time" @@ -33,90 +35,87 @@ import ( type LoopbackNatsClient struct { mu sync.Mutex subscriptions map[string]map[*loopbackNatsSubscription]bool + + stopping bool + wakeup sync.Cond + incoming list.List } func NewLoopbackNatsClient() (NatsClient, error) { - return &LoopbackNatsClient{ + client := &LoopbackNatsClient{ subscriptions: make(map[string]map[*loopbackNatsSubscription]bool), - }, nil + } + client.wakeup.L = &client.mu + go client.processMessages() + return client, nil +} + +func (c *LoopbackNatsClient) processMessages() { + c.mu.Lock() + defer c.mu.Unlock() + for { + for !c.stopping && c.incoming.Len() == 0 { + c.wakeup.Wait() + } + if c.stopping { + break + } + + msg := c.incoming.Remove(c.incoming.Front()).(*nats.Msg) + c.processMessage(msg) + } +} + +func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) { + subs, found := c.subscriptions[msg.Subject] + if !found { + return + } + + channels := make([]chan *nats.Msg, 0, len(subs)) + for sub := range subs { + channels = append(channels, sub.ch) + } + c.mu.Unlock() + defer c.mu.Lock() + for _, ch := range channels { + select { + case ch <- msg: + default: + log.Printf("Slow consumer %s, dropping message", msg.Subject) + } + } } func (c *LoopbackNatsClient) Close() { c.mu.Lock() defer c.mu.Unlock() - for _, subs := range c.subscriptions { - for sub := range subs { - sub.Unsubscribe() // nolint - } - } - c.subscriptions = nil + c.stopping = true + c.incoming.Init() + c.wakeup.Signal() } type loopbackNatsSubscription struct { - subject string - client *LoopbackNatsClient - ch chan *nats.Msg - incoming []*nats.Msg - cond sync.Cond - quit bool + subject string + client *LoopbackNatsClient + + ch chan *nats.Msg } func (s *loopbackNatsSubscription) Unsubscribe() error { - s.cond.L.Lock() - if !s.quit { - s.quit = true - s.cond.Signal() - } - s.cond.L.Unlock() - s.client.unsubscribe(s) return nil } -func (s *loopbackNatsSubscription) queue(msg *nats.Msg) { - s.cond.L.Lock() - s.incoming = append(s.incoming, msg) - if len(s.incoming) == 1 { - s.cond.Signal() - } - s.cond.L.Unlock() -} - -func (s *loopbackNatsSubscription) run() { - s.cond.L.Lock() - defer s.cond.L.Unlock() - for !s.quit { - for !s.quit && len(s.incoming) == 0 { - s.cond.Wait() - } - - for !s.quit && len(s.incoming) > 0 { - msg := s.incoming[0] - s.incoming = s.incoming[1:] - s.cond.L.Unlock() - // A "real" NATS server would take some time to process the request, - // simulate this by sleeping a tiny bit. - time.Sleep(time.Millisecond) - s.ch <- msg - s.cond.L.Lock() - } - } -} - func (c *LoopbackNatsClient) Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) { - c.mu.Lock() - defer c.mu.Unlock() - - return c.subscribe(subject, ch) -} - -func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) { if strings.HasSuffix(subject, ".") || strings.Contains(subject, " ") { return nil, nats.ErrBadSubject } + c.mu.Lock() + defer c.mu.Unlock() if c.subscriptions == nil { return nil, nats.ErrConnectionClosed } @@ -126,7 +125,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS client: c, ch: ch, } - s.cond.L = &sync.Mutex{} subs, found := c.subscriptions[subject] if !found { subs = make(map[*loopbackNatsSubscription]bool) @@ -134,7 +132,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS } subs[s] = true - go s.run() return s, nil } @@ -161,18 +158,15 @@ func (c *LoopbackNatsClient) Publish(subject string, message interface{}) error return nats.ErrConnectionClosed } - if subs, found := c.subscriptions[subject]; found { - msg := &nats.Msg{ - Subject: subject, - } - var err error - if msg.Data, err = json.Marshal(message); err != nil { - return err - } - for s := range subs { - s.queue(msg) - } + msg := &nats.Msg{ + Subject: subject, } + var err error + if msg.Data, err = json.Marshal(message); err != nil { + return err + } + c.incoming.PushBack(msg) + c.wakeup.Signal() return nil } @@ -182,16 +176,18 @@ func (c *LoopbackNatsClient) PublishNats(subject string, message *NatsMessage) e func (c *LoopbackNatsClient) PublishMessage(subject string, message *ServerMessage) error { msg := &NatsMessage{ - Type: "message", - Message: message, + SendTime: time.Now(), + Type: "message", + Message: message, } return c.PublishNats(subject, msg) } func (c *LoopbackNatsClient) PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error { msg := &NatsMessage{ - Type: "room", - Room: message, + SendTime: time.Now(), + Type: "room", + Room: message, } return c.PublishNats(subject, msg) } diff --git a/natsclient_loopback_test.go b/natsclient_loopback_test.go index 865498a..99aad5b 100644 --- a/natsclient_loopback_test.go +++ b/natsclient_loopback_test.go @@ -48,17 +48,20 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t } } -func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient { +func CreateLoopbackNatsClientForTest(t *testing.T) (NatsClient, func()) { result, err := NewLoopbackNatsClient() if err != nil { t.Fatal(err) } - return result + return result, func() { + result.Close() + } } func TestLoopbackNatsClient_Subscribe(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_Subscribe(t, client) }) @@ -66,7 +69,8 @@ func TestLoopbackNatsClient_Subscribe(t *testing.T) { func TestLoopbackClient_PublishAfterClose(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_PublishAfterClose(t, client) }) @@ -74,7 +78,8 @@ func TestLoopbackClient_PublishAfterClose(t *testing.T) { func TestLoopbackClient_SubscribeAfterClose(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_SubscribeAfterClose(t, client) }) @@ -82,7 +87,8 @@ func TestLoopbackClient_SubscribeAfterClose(t *testing.T) { func TestLoopbackClient_BadSubjects(t *testing.T) { ensureNoGoroutinesLeak(t, func() { - client := CreateLoopbackNatsClientForTest(t) + client, shutdown := CreateLoopbackNatsClientForTest(t) + defer shutdown() testNatsClient_BadSubjects(t, client) }) diff --git a/natsclient_test.go b/natsclient_test.go index 7afe06c..67f377b 100644 --- a/natsclient_test.go +++ b/natsclient_test.go @@ -90,7 +90,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { } // Allow NATS goroutines to process messages. - time.Sleep(time.Millisecond) + time.Sleep(10 * time.Millisecond) } <-ch diff --git a/testclient_test.go b/testclient_test.go index 63ffc78..9dc275a 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -227,7 +227,13 @@ func (c *TestClient) CloseWithBye() { } func (c *TestClient) Close() { - c.conn.WriteMessage(websocket.CloseMessage, []byte{}) // nolint + if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err == websocket.ErrCloseSent { + // Already closed + return + } + + // Wait a bit for close message to be processed. + time.Sleep(100 * time.Millisecond) c.conn.Close() // Drain any entries in the channels to terminate the read goroutine.