mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-05-02 05:52:44 +02:00
Merge pull request #115 from strukturag/test-improvements
Various test improvements
This commit is contained in:
commit
35d3bf84e6
|
@ -95,6 +95,14 @@ func (m *ClientMessage) CheckValid() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *ClientMessage) String() string {
|
||||||
|
data, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("Could not serialize %#v: %s", m, err)
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *ClientMessage) NewErrorServerMessage(e *Error) *ServerMessage {
|
func (m *ClientMessage) NewErrorServerMessage(e *Error) *ServerMessage {
|
||||||
return &ServerMessage{
|
return &ServerMessage{
|
||||||
Id: m.Id,
|
Id: m.Id,
|
||||||
|
@ -179,6 +187,14 @@ func (r *ServerMessage) IsParticipantsUpdate() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *ServerMessage) String() string {
|
||||||
|
data, err := json.Marshal(r)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("Could not serialize %#v: %s", r, err)
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
|
|
|
@ -32,47 +32,59 @@ import (
|
||||||
|
|
||||||
func testUrls(t *testing.T, config *BackendConfiguration, valid_urls []string, invalid_urls []string) {
|
func testUrls(t *testing.T, config *BackendConfiguration, valid_urls []string, invalid_urls []string) {
|
||||||
for _, u := range valid_urls {
|
for _, u := range valid_urls {
|
||||||
parsed, err := url.ParseRequestURI(u)
|
u := u
|
||||||
if err != nil {
|
t.Run(u, func(t *testing.T) {
|
||||||
t.Errorf("The url %s should be valid, got %s", u, err)
|
parsed, err := url.ParseRequestURI(u)
|
||||||
continue
|
if err != nil {
|
||||||
}
|
t.Errorf("The url %s should be valid, got %s", u, err)
|
||||||
if !config.IsUrlAllowed(parsed) {
|
return
|
||||||
t.Errorf("The url %s should be allowed", u)
|
}
|
||||||
}
|
if !config.IsUrlAllowed(parsed) {
|
||||||
if secret := config.GetSecret(parsed); !bytes.Equal(secret, testBackendSecret) {
|
t.Errorf("The url %s should be allowed", u)
|
||||||
t.Errorf("Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret))
|
}
|
||||||
}
|
if secret := config.GetSecret(parsed); !bytes.Equal(secret, testBackendSecret) {
|
||||||
|
t.Errorf("Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret))
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
for _, u := range invalid_urls {
|
for _, u := range invalid_urls {
|
||||||
parsed, _ := url.ParseRequestURI(u)
|
u := u
|
||||||
if config.IsUrlAllowed(parsed) {
|
t.Run(u, func(t *testing.T) {
|
||||||
t.Errorf("The url %s should not be allowed", u)
|
parsed, _ := url.ParseRequestURI(u)
|
||||||
}
|
if config.IsUrlAllowed(parsed) {
|
||||||
|
t.Errorf("The url %s should not be allowed", u)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testBackends(t *testing.T, config *BackendConfiguration, valid_urls [][]string, invalid_urls []string) {
|
func testBackends(t *testing.T, config *BackendConfiguration, valid_urls [][]string, invalid_urls []string) {
|
||||||
for _, entry := range valid_urls {
|
for _, entry := range valid_urls {
|
||||||
u := entry[0]
|
entry := entry
|
||||||
parsed, err := url.ParseRequestURI(u)
|
t.Run(entry[0], func(t *testing.T) {
|
||||||
if err != nil {
|
u := entry[0]
|
||||||
t.Errorf("The url %s should be valid, got %s", u, err)
|
parsed, err := url.ParseRequestURI(u)
|
||||||
continue
|
if err != nil {
|
||||||
}
|
t.Errorf("The url %s should be valid, got %s", u, err)
|
||||||
if !config.IsUrlAllowed(parsed) {
|
return
|
||||||
t.Errorf("The url %s should be allowed", u)
|
}
|
||||||
}
|
if !config.IsUrlAllowed(parsed) {
|
||||||
s := entry[1]
|
t.Errorf("The url %s should be allowed", u)
|
||||||
if secret := config.GetSecret(parsed); !bytes.Equal(secret, []byte(s)) {
|
}
|
||||||
t.Errorf("Expected secret %s for url %s, got %s", string(s), u, string(secret))
|
s := entry[1]
|
||||||
}
|
if secret := config.GetSecret(parsed); !bytes.Equal(secret, []byte(s)) {
|
||||||
|
t.Errorf("Expected secret %s for url %s, got %s", string(s), u, string(secret))
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
for _, u := range invalid_urls {
|
for _, u := range invalid_urls {
|
||||||
parsed, _ := url.ParseRequestURI(u)
|
u := u
|
||||||
if config.IsUrlAllowed(parsed) {
|
t.Run(u, func(t *testing.T) {
|
||||||
t.Errorf("The url %s should not be allowed", u)
|
parsed, _ := url.ParseRequestURI(u)
|
||||||
}
|
if config.IsUrlAllowed(parsed) {
|
||||||
|
t.Errorf("The url %s should not be allowed", u)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -106,6 +106,7 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil
|
||||||
|
|
||||||
WaitForHub(ctx, t, hub)
|
WaitForHub(ctx, t, hub)
|
||||||
(nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t)
|
(nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t)
|
||||||
|
nats.Close()
|
||||||
server.Close()
|
server.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
33
client.go
33
client.go
|
@ -100,9 +100,10 @@ type Client struct {
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
closeChan chan bool
|
closeChan chan bool
|
||||||
messagesDone sync.WaitGroup
|
messagesDone sync.WaitGroup
|
||||||
messageChan chan *bytes.Buffer
|
messageChan chan *bytes.Buffer
|
||||||
|
messageProcessing uint32
|
||||||
|
|
||||||
OnLookupCountry func(*Client) string
|
OnLookupCountry func(*Client) string
|
||||||
OnClosed func(*Client)
|
OnClosed func(*Client)
|
||||||
|
@ -183,9 +184,24 @@ func (c *Client) Close() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) // nolint
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if atomic.LoadUint32(&c.messageProcessing) == 1 {
|
||||||
|
// Defer closing
|
||||||
|
atomic.StoreUint32(&c.closed, 2)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.doClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) doClose() {
|
||||||
c.closeChan <- true
|
c.closeChan <- true
|
||||||
c.messagesDone.Wait()
|
c.messagesDone.Wait()
|
||||||
close(c.messageChan)
|
|
||||||
|
|
||||||
c.OnClosed(c)
|
c.OnClosed(c)
|
||||||
c.SetSession(nil)
|
c.SetSession(nil)
|
||||||
|
@ -231,6 +247,7 @@ func (c *Client) SendMessage(message WritableClientMessage) bool {
|
||||||
func (c *Client) ReadPump() {
|
func (c *Client) ReadPump() {
|
||||||
defer func() {
|
defer func() {
|
||||||
c.Close()
|
c.Close()
|
||||||
|
close(c.messageChan)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
addr := c.RemoteAddr()
|
addr := c.RemoteAddr()
|
||||||
|
@ -304,7 +321,7 @@ func (c *Client) ReadPump() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop processing if the client was closed.
|
// Stop processing if the client was closed.
|
||||||
if atomic.LoadUint32(&c.closed) == 1 {
|
if atomic.LoadUint32(&c.closed) != 0 {
|
||||||
bufferPool.Put(decodeBuffer)
|
bufferPool.Put(decodeBuffer)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -321,10 +338,16 @@ func (c *Client) processMessages() {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
atomic.StoreUint32(&c.messageProcessing, 1)
|
||||||
c.OnMessageReceived(c, buffer.Bytes())
|
c.OnMessageReceived(c, buffer.Bytes())
|
||||||
|
atomic.StoreUint32(&c.messageProcessing, 0)
|
||||||
c.messagesDone.Done()
|
c.messagesDone.Done()
|
||||||
bufferPool.Put(buffer)
|
bufferPool.Put(buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if atomic.LoadUint32(&c.closed) == 2 {
|
||||||
|
c.doClose()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) writeInternal(message json.Marshaler) bool {
|
func (c *Client) writeInternal(message json.Marshaler) bool {
|
||||||
|
|
|
@ -46,6 +46,8 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type ClientSession struct {
|
type ClientSession struct {
|
||||||
|
roomJoinTime int64
|
||||||
|
|
||||||
running int32
|
running int32
|
||||||
hub *Hub
|
hub *Hub
|
||||||
privateId string
|
privateId string
|
||||||
|
@ -289,12 +291,21 @@ func (s *ClientSession) IsExpired(now time.Time) bool {
|
||||||
|
|
||||||
func (s *ClientSession) SetRoom(room *Room) {
|
func (s *ClientSession) SetRoom(room *Room) {
|
||||||
atomic.StorePointer(&s.room, unsafe.Pointer(room))
|
atomic.StorePointer(&s.room, unsafe.Pointer(room))
|
||||||
|
if room != nil {
|
||||||
|
atomic.StoreInt64(&s.roomJoinTime, time.Now().UnixNano())
|
||||||
|
} else {
|
||||||
|
atomic.StoreInt64(&s.roomJoinTime, 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClientSession) GetRoom() *Room {
|
func (s *ClientSession) GetRoom() *Room {
|
||||||
return (*Room)(atomic.LoadPointer(&s.room))
|
return (*Room)(atomic.LoadPointer(&s.room))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ClientSession) getRoomJoinTime() time.Time {
|
||||||
|
return time.Unix(0, atomic.LoadInt64(&s.roomJoinTime))
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ClientSession) releaseMcuObjects() {
|
func (s *ClientSession) releaseMcuObjects() {
|
||||||
if len(s.publishers) > 0 {
|
if len(s.publishers) > 0 {
|
||||||
go func(publishers map[string]McuPublisher) {
|
go func(publishers map[string]McuPublisher) {
|
||||||
|
@ -815,6 +826,13 @@ func (s *ClientSession) processNatsMessage(msg *NatsMessage) *ServerMessage {
|
||||||
// TODO(jojo): Only send all users if current session id has
|
// TODO(jojo): Only send all users if current session id has
|
||||||
// changed its "inCall" flag to true.
|
// changed its "inCall" flag to true.
|
||||||
m.Changed = nil
|
m.Changed = nil
|
||||||
|
} else if msg.Message.Event.Target == "room" {
|
||||||
|
// Can happen mostly during tests where an older room NATS message
|
||||||
|
// could be received by a subscriber that joined after it was sent.
|
||||||
|
if msg.SendTime.Before(s.getRoomJoinTime()) {
|
||||||
|
log.Printf("Message %+v was sent before room was joined, ignoring", msg.Message)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
package signaling
|
package signaling
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -111,10 +112,13 @@ func Test_permissionsEqual(t *testing.T) {
|
||||||
equal: false,
|
equal: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for idx, test := range tests {
|
||||||
equal := permissionsEqual(test.a, test.b)
|
test := test
|
||||||
if equal != test.equal {
|
t.Run(strconv.Itoa(idx), func(t *testing.T) {
|
||||||
t.Errorf("Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal])
|
equal := permissionsEqual(test.a, test.b)
|
||||||
}
|
if equal != test.equal {
|
||||||
|
t.Errorf("Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal])
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,15 +42,19 @@ func testGeoLookupReader(t *testing.T, reader *GeoLookup) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for ip, expected := range tests {
|
for ip, expected := range tests {
|
||||||
country, err := reader.LookupCountry(net.ParseIP(ip))
|
ip := ip
|
||||||
if err != nil {
|
expected := expected
|
||||||
t.Errorf("Could not lookup %s: %s", ip, err)
|
t.Run(ip, func(t *testing.T) {
|
||||||
continue
|
country, err := reader.LookupCountry(net.ParseIP(ip))
|
||||||
}
|
if err != nil {
|
||||||
|
t.Errorf("Could not lookup %s: %s", ip, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if country != expected {
|
if country != expected {
|
||||||
t.Errorf("Expected %s for %s, got %s", expected, ip, country)
|
t.Errorf("Expected %s for %s, got %s", expected, ip, country)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,17 +110,21 @@ func TestGeoLookupContinent(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for country, expected := range tests {
|
for country, expected := range tests {
|
||||||
continents := LookupContinents(country)
|
country := country
|
||||||
if len(continents) != len(expected) {
|
expected := expected
|
||||||
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
|
t.Run(country, func(t *testing.T) {
|
||||||
continue
|
continents := LookupContinents(country)
|
||||||
}
|
if len(continents) != len(expected) {
|
||||||
for idx, c := range expected {
|
|
||||||
if continents[idx] != c {
|
|
||||||
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
|
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
|
||||||
break
|
return
|
||||||
}
|
}
|
||||||
}
|
for idx, c := range expected {
|
||||||
|
if continents[idx] != c {
|
||||||
|
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -120,6 +120,7 @@ func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Serve
|
||||||
|
|
||||||
WaitForHub(ctx, t, h)
|
WaitForHub(ctx, t, h)
|
||||||
(nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t)
|
(nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t)
|
||||||
|
nats.Close()
|
||||||
server.Close()
|
server.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -76,11 +76,15 @@ func Test_sortConnectionsForCountry(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for country, test := range testcases {
|
for country, test := range testcases {
|
||||||
sorted := sortConnectionsForCountry(test[0], country)
|
country := country
|
||||||
for idx, conn := range sorted {
|
test := test
|
||||||
if test[1][idx] != conn {
|
t.Run(country, func(t *testing.T) {
|
||||||
t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
|
sorted := sortConnectionsForCountry(test[0], country)
|
||||||
|
for idx, conn := range sorted {
|
||||||
|
if test[1][idx] != conn {
|
||||||
|
t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,6 +38,8 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type NatsMessage struct {
|
type NatsMessage struct {
|
||||||
|
SendTime time.Time `json:"sendtime"`
|
||||||
|
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
Message *ServerMessage `json:"message,omitempty"`
|
Message *ServerMessage `json:"message,omitempty"`
|
||||||
|
@ -150,16 +152,18 @@ func (c *natsClient) PublishNats(subject string, message *NatsMessage) error {
|
||||||
|
|
||||||
func (c *natsClient) PublishMessage(subject string, message *ServerMessage) error {
|
func (c *natsClient) PublishMessage(subject string, message *ServerMessage) error {
|
||||||
msg := &NatsMessage{
|
msg := &NatsMessage{
|
||||||
Type: "message",
|
SendTime: time.Now(),
|
||||||
Message: message,
|
Type: "message",
|
||||||
|
Message: message,
|
||||||
}
|
}
|
||||||
return c.PublishNats(subject, msg)
|
return c.PublishNats(subject, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *natsClient) PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error {
|
func (c *natsClient) PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error {
|
||||||
msg := &NatsMessage{
|
msg := &NatsMessage{
|
||||||
Type: "room",
|
SendTime: time.Now(),
|
||||||
Room: message,
|
Type: "room",
|
||||||
|
Room: message,
|
||||||
}
|
}
|
||||||
return c.PublishNats(subject, msg)
|
return c.PublishNats(subject, msg)
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,9 @@
|
||||||
package signaling
|
package signaling
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"container/list"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -33,90 +35,87 @@ import (
|
||||||
type LoopbackNatsClient struct {
|
type LoopbackNatsClient struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
subscriptions map[string]map[*loopbackNatsSubscription]bool
|
subscriptions map[string]map[*loopbackNatsSubscription]bool
|
||||||
|
|
||||||
|
stopping bool
|
||||||
|
wakeup sync.Cond
|
||||||
|
incoming list.List
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLoopbackNatsClient() (NatsClient, error) {
|
func NewLoopbackNatsClient() (NatsClient, error) {
|
||||||
return &LoopbackNatsClient{
|
client := &LoopbackNatsClient{
|
||||||
subscriptions: make(map[string]map[*loopbackNatsSubscription]bool),
|
subscriptions: make(map[string]map[*loopbackNatsSubscription]bool),
|
||||||
}, nil
|
}
|
||||||
|
client.wakeup.L = &client.mu
|
||||||
|
go client.processMessages()
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LoopbackNatsClient) processMessages() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
for {
|
||||||
|
for !c.stopping && c.incoming.Len() == 0 {
|
||||||
|
c.wakeup.Wait()
|
||||||
|
}
|
||||||
|
if c.stopping {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := c.incoming.Remove(c.incoming.Front()).(*nats.Msg)
|
||||||
|
c.processMessage(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) {
|
||||||
|
subs, found := c.subscriptions[msg.Subject]
|
||||||
|
if !found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
channels := make([]chan *nats.Msg, 0, len(subs))
|
||||||
|
for sub := range subs {
|
||||||
|
channels = append(channels, sub.ch)
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
defer c.mu.Lock()
|
||||||
|
for _, ch := range channels {
|
||||||
|
select {
|
||||||
|
case ch <- msg:
|
||||||
|
default:
|
||||||
|
log.Printf("Slow consumer %s, dropping message", msg.Subject)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *LoopbackNatsClient) Close() {
|
func (c *LoopbackNatsClient) Close() {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
for _, subs := range c.subscriptions {
|
|
||||||
for sub := range subs {
|
|
||||||
sub.Unsubscribe() // nolint
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.subscriptions = nil
|
c.subscriptions = nil
|
||||||
|
c.stopping = true
|
||||||
|
c.incoming.Init()
|
||||||
|
c.wakeup.Signal()
|
||||||
}
|
}
|
||||||
|
|
||||||
type loopbackNatsSubscription struct {
|
type loopbackNatsSubscription struct {
|
||||||
subject string
|
subject string
|
||||||
client *LoopbackNatsClient
|
client *LoopbackNatsClient
|
||||||
ch chan *nats.Msg
|
|
||||||
incoming []*nats.Msg
|
ch chan *nats.Msg
|
||||||
cond sync.Cond
|
|
||||||
quit bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *loopbackNatsSubscription) Unsubscribe() error {
|
func (s *loopbackNatsSubscription) Unsubscribe() error {
|
||||||
s.cond.L.Lock()
|
|
||||||
if !s.quit {
|
|
||||||
s.quit = true
|
|
||||||
s.cond.Signal()
|
|
||||||
}
|
|
||||||
s.cond.L.Unlock()
|
|
||||||
|
|
||||||
s.client.unsubscribe(s)
|
s.client.unsubscribe(s)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *loopbackNatsSubscription) queue(msg *nats.Msg) {
|
|
||||||
s.cond.L.Lock()
|
|
||||||
s.incoming = append(s.incoming, msg)
|
|
||||||
if len(s.incoming) == 1 {
|
|
||||||
s.cond.Signal()
|
|
||||||
}
|
|
||||||
s.cond.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *loopbackNatsSubscription) run() {
|
|
||||||
s.cond.L.Lock()
|
|
||||||
defer s.cond.L.Unlock()
|
|
||||||
for !s.quit {
|
|
||||||
for !s.quit && len(s.incoming) == 0 {
|
|
||||||
s.cond.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
for !s.quit && len(s.incoming) > 0 {
|
|
||||||
msg := s.incoming[0]
|
|
||||||
s.incoming = s.incoming[1:]
|
|
||||||
s.cond.L.Unlock()
|
|
||||||
// A "real" NATS server would take some time to process the request,
|
|
||||||
// simulate this by sleeping a tiny bit.
|
|
||||||
time.Sleep(time.Millisecond)
|
|
||||||
s.ch <- msg
|
|
||||||
s.cond.L.Lock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LoopbackNatsClient) Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) {
|
func (c *LoopbackNatsClient) Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) {
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
return c.subscribe(subject, ch)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) {
|
|
||||||
if strings.HasSuffix(subject, ".") || strings.Contains(subject, " ") {
|
if strings.HasSuffix(subject, ".") || strings.Contains(subject, " ") {
|
||||||
return nil, nats.ErrBadSubject
|
return nil, nats.ErrBadSubject
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
if c.subscriptions == nil {
|
if c.subscriptions == nil {
|
||||||
return nil, nats.ErrConnectionClosed
|
return nil, nats.ErrConnectionClosed
|
||||||
}
|
}
|
||||||
|
@ -126,7 +125,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS
|
||||||
client: c,
|
client: c,
|
||||||
ch: ch,
|
ch: ch,
|
||||||
}
|
}
|
||||||
s.cond.L = &sync.Mutex{}
|
|
||||||
subs, found := c.subscriptions[subject]
|
subs, found := c.subscriptions[subject]
|
||||||
if !found {
|
if !found {
|
||||||
subs = make(map[*loopbackNatsSubscription]bool)
|
subs = make(map[*loopbackNatsSubscription]bool)
|
||||||
|
@ -134,7 +132,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS
|
||||||
}
|
}
|
||||||
subs[s] = true
|
subs[s] = true
|
||||||
|
|
||||||
go s.run()
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,18 +158,15 @@ func (c *LoopbackNatsClient) Publish(subject string, message interface{}) error
|
||||||
return nats.ErrConnectionClosed
|
return nats.ErrConnectionClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
if subs, found := c.subscriptions[subject]; found {
|
msg := &nats.Msg{
|
||||||
msg := &nats.Msg{
|
Subject: subject,
|
||||||
Subject: subject,
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
if msg.Data, err = json.Marshal(message); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for s := range subs {
|
|
||||||
s.queue(msg)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
var err error
|
||||||
|
if msg.Data, err = json.Marshal(message); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.incoming.PushBack(msg)
|
||||||
|
c.wakeup.Signal()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,16 +176,18 @@ func (c *LoopbackNatsClient) PublishNats(subject string, message *NatsMessage) e
|
||||||
|
|
||||||
func (c *LoopbackNatsClient) PublishMessage(subject string, message *ServerMessage) error {
|
func (c *LoopbackNatsClient) PublishMessage(subject string, message *ServerMessage) error {
|
||||||
msg := &NatsMessage{
|
msg := &NatsMessage{
|
||||||
Type: "message",
|
SendTime: time.Now(),
|
||||||
Message: message,
|
Type: "message",
|
||||||
|
Message: message,
|
||||||
}
|
}
|
||||||
return c.PublishNats(subject, msg)
|
return c.PublishNats(subject, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *LoopbackNatsClient) PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error {
|
func (c *LoopbackNatsClient) PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error {
|
||||||
msg := &NatsMessage{
|
msg := &NatsMessage{
|
||||||
Type: "room",
|
SendTime: time.Now(),
|
||||||
Room: message,
|
Type: "room",
|
||||||
|
Room: message,
|
||||||
}
|
}
|
||||||
return c.PublishNats(subject, msg)
|
return c.PublishNats(subject, msg)
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,17 +48,20 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient {
|
func CreateLoopbackNatsClientForTest(t *testing.T) (NatsClient, func()) {
|
||||||
result, err := NewLoopbackNatsClient()
|
result, err := NewLoopbackNatsClient()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
return result
|
return result, func() {
|
||||||
|
result.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoopbackNatsClient_Subscribe(t *testing.T) {
|
func TestLoopbackNatsClient_Subscribe(t *testing.T) {
|
||||||
ensureNoGoroutinesLeak(t, func() {
|
ensureNoGoroutinesLeak(t, func() {
|
||||||
client := CreateLoopbackNatsClientForTest(t)
|
client, shutdown := CreateLoopbackNatsClientForTest(t)
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
testNatsClient_Subscribe(t, client)
|
testNatsClient_Subscribe(t, client)
|
||||||
})
|
})
|
||||||
|
@ -66,7 +69,8 @@ func TestLoopbackNatsClient_Subscribe(t *testing.T) {
|
||||||
|
|
||||||
func TestLoopbackClient_PublishAfterClose(t *testing.T) {
|
func TestLoopbackClient_PublishAfterClose(t *testing.T) {
|
||||||
ensureNoGoroutinesLeak(t, func() {
|
ensureNoGoroutinesLeak(t, func() {
|
||||||
client := CreateLoopbackNatsClientForTest(t)
|
client, shutdown := CreateLoopbackNatsClientForTest(t)
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
testNatsClient_PublishAfterClose(t, client)
|
testNatsClient_PublishAfterClose(t, client)
|
||||||
})
|
})
|
||||||
|
@ -74,7 +78,8 @@ func TestLoopbackClient_PublishAfterClose(t *testing.T) {
|
||||||
|
|
||||||
func TestLoopbackClient_SubscribeAfterClose(t *testing.T) {
|
func TestLoopbackClient_SubscribeAfterClose(t *testing.T) {
|
||||||
ensureNoGoroutinesLeak(t, func() {
|
ensureNoGoroutinesLeak(t, func() {
|
||||||
client := CreateLoopbackNatsClientForTest(t)
|
client, shutdown := CreateLoopbackNatsClientForTest(t)
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
testNatsClient_SubscribeAfterClose(t, client)
|
testNatsClient_SubscribeAfterClose(t, client)
|
||||||
})
|
})
|
||||||
|
@ -82,7 +87,8 @@ func TestLoopbackClient_SubscribeAfterClose(t *testing.T) {
|
||||||
|
|
||||||
func TestLoopbackClient_BadSubjects(t *testing.T) {
|
func TestLoopbackClient_BadSubjects(t *testing.T) {
|
||||||
ensureNoGoroutinesLeak(t, func() {
|
ensureNoGoroutinesLeak(t, func() {
|
||||||
client := CreateLoopbackNatsClientForTest(t)
|
client, shutdown := CreateLoopbackNatsClientForTest(t)
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
testNatsClient_BadSubjects(t, client)
|
testNatsClient_BadSubjects(t, client)
|
||||||
})
|
})
|
||||||
|
|
|
@ -90,7 +90,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allow NATS goroutines to process messages.
|
// Allow NATS goroutines to process messages.
|
||||||
time.Sleep(time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
<-ch
|
<-ch
|
||||||
|
|
||||||
|
|
|
@ -227,7 +227,13 @@ func (c *TestClient) CloseWithBye() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TestClient) Close() {
|
func (c *TestClient) Close() {
|
||||||
c.conn.WriteMessage(websocket.CloseMessage, []byte{}) // nolint
|
if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err == websocket.ErrCloseSent {
|
||||||
|
// Already closed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a bit for close message to be processed.
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
|
|
||||||
// Drain any entries in the channels to terminate the read goroutine.
|
// Drain any entries in the channels to terminate the read goroutine.
|
||||||
|
|
Loading…
Reference in a new issue