Merge pull request #115 from strukturag/test-improvements
Various test improvements
This commit is contained in:
commit
35d3bf84e6
|
@ -95,6 +95,14 @@ func (m *ClientMessage) CheckValid() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *ClientMessage) String() string {
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not serialize %#v: %s", m, err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func (m *ClientMessage) NewErrorServerMessage(e *Error) *ServerMessage {
|
||||
return &ServerMessage{
|
||||
Id: m.Id,
|
||||
|
@ -179,6 +187,14 @@ func (r *ServerMessage) IsParticipantsUpdate() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (r *ServerMessage) String() string {
|
||||
data, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not serialize %#v: %s", r, err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
|
|
|
@ -32,47 +32,59 @@ import (
|
|||
|
||||
func testUrls(t *testing.T, config *BackendConfiguration, valid_urls []string, invalid_urls []string) {
|
||||
for _, u := range valid_urls {
|
||||
parsed, err := url.ParseRequestURI(u)
|
||||
if err != nil {
|
||||
t.Errorf("The url %s should be valid, got %s", u, err)
|
||||
continue
|
||||
}
|
||||
if !config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should be allowed", u)
|
||||
}
|
||||
if secret := config.GetSecret(parsed); !bytes.Equal(secret, testBackendSecret) {
|
||||
t.Errorf("Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret))
|
||||
}
|
||||
u := u
|
||||
t.Run(u, func(t *testing.T) {
|
||||
parsed, err := url.ParseRequestURI(u)
|
||||
if err != nil {
|
||||
t.Errorf("The url %s should be valid, got %s", u, err)
|
||||
return
|
||||
}
|
||||
if !config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should be allowed", u)
|
||||
}
|
||||
if secret := config.GetSecret(parsed); !bytes.Equal(secret, testBackendSecret) {
|
||||
t.Errorf("Expected secret %s for url %s, got %s", string(testBackendSecret), u, string(secret))
|
||||
}
|
||||
})
|
||||
}
|
||||
for _, u := range invalid_urls {
|
||||
parsed, _ := url.ParseRequestURI(u)
|
||||
if config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should not be allowed", u)
|
||||
}
|
||||
u := u
|
||||
t.Run(u, func(t *testing.T) {
|
||||
parsed, _ := url.ParseRequestURI(u)
|
||||
if config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should not be allowed", u)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testBackends(t *testing.T, config *BackendConfiguration, valid_urls [][]string, invalid_urls []string) {
|
||||
for _, entry := range valid_urls {
|
||||
u := entry[0]
|
||||
parsed, err := url.ParseRequestURI(u)
|
||||
if err != nil {
|
||||
t.Errorf("The url %s should be valid, got %s", u, err)
|
||||
continue
|
||||
}
|
||||
if !config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should be allowed", u)
|
||||
}
|
||||
s := entry[1]
|
||||
if secret := config.GetSecret(parsed); !bytes.Equal(secret, []byte(s)) {
|
||||
t.Errorf("Expected secret %s for url %s, got %s", string(s), u, string(secret))
|
||||
}
|
||||
entry := entry
|
||||
t.Run(entry[0], func(t *testing.T) {
|
||||
u := entry[0]
|
||||
parsed, err := url.ParseRequestURI(u)
|
||||
if err != nil {
|
||||
t.Errorf("The url %s should be valid, got %s", u, err)
|
||||
return
|
||||
}
|
||||
if !config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should be allowed", u)
|
||||
}
|
||||
s := entry[1]
|
||||
if secret := config.GetSecret(parsed); !bytes.Equal(secret, []byte(s)) {
|
||||
t.Errorf("Expected secret %s for url %s, got %s", string(s), u, string(secret))
|
||||
}
|
||||
})
|
||||
}
|
||||
for _, u := range invalid_urls {
|
||||
parsed, _ := url.ParseRequestURI(u)
|
||||
if config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should not be allowed", u)
|
||||
}
|
||||
u := u
|
||||
t.Run(u, func(t *testing.T) {
|
||||
parsed, _ := url.ParseRequestURI(u)
|
||||
if config.IsUrlAllowed(parsed) {
|
||||
t.Errorf("The url %s should not be allowed", u)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -106,6 +106,7 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil
|
|||
|
||||
WaitForHub(ctx, t, hub)
|
||||
(nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t)
|
||||
nats.Close()
|
||||
server.Close()
|
||||
}
|
||||
|
||||
|
|
33
client.go
33
client.go
|
@ -100,9 +100,10 @@ type Client struct {
|
|||
|
||||
mu sync.Mutex
|
||||
|
||||
closeChan chan bool
|
||||
messagesDone sync.WaitGroup
|
||||
messageChan chan *bytes.Buffer
|
||||
closeChan chan bool
|
||||
messagesDone sync.WaitGroup
|
||||
messageChan chan *bytes.Buffer
|
||||
messageProcessing uint32
|
||||
|
||||
OnLookupCountry func(*Client) string
|
||||
OnClosed func(*Client)
|
||||
|
@ -183,9 +184,24 @@ func (c *Client) Close() {
|
|||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) // nolint
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if atomic.LoadUint32(&c.messageProcessing) == 1 {
|
||||
// Defer closing
|
||||
atomic.StoreUint32(&c.closed, 2)
|
||||
return
|
||||
}
|
||||
|
||||
c.doClose()
|
||||
}
|
||||
|
||||
func (c *Client) doClose() {
|
||||
c.closeChan <- true
|
||||
c.messagesDone.Wait()
|
||||
close(c.messageChan)
|
||||
|
||||
c.OnClosed(c)
|
||||
c.SetSession(nil)
|
||||
|
@ -231,6 +247,7 @@ func (c *Client) SendMessage(message WritableClientMessage) bool {
|
|||
func (c *Client) ReadPump() {
|
||||
defer func() {
|
||||
c.Close()
|
||||
close(c.messageChan)
|
||||
}()
|
||||
|
||||
addr := c.RemoteAddr()
|
||||
|
@ -304,7 +321,7 @@ func (c *Client) ReadPump() {
|
|||
}
|
||||
|
||||
// Stop processing if the client was closed.
|
||||
if atomic.LoadUint32(&c.closed) == 1 {
|
||||
if atomic.LoadUint32(&c.closed) != 0 {
|
||||
bufferPool.Put(decodeBuffer)
|
||||
break
|
||||
}
|
||||
|
@ -321,10 +338,16 @@ func (c *Client) processMessages() {
|
|||
break
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&c.messageProcessing, 1)
|
||||
c.OnMessageReceived(c, buffer.Bytes())
|
||||
atomic.StoreUint32(&c.messageProcessing, 0)
|
||||
c.messagesDone.Done()
|
||||
bufferPool.Put(buffer)
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&c.closed) == 2 {
|
||||
c.doClose()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) writeInternal(message json.Marshaler) bool {
|
||||
|
|
|
@ -46,6 +46,8 @@ var (
|
|||
)
|
||||
|
||||
type ClientSession struct {
|
||||
roomJoinTime int64
|
||||
|
||||
running int32
|
||||
hub *Hub
|
||||
privateId string
|
||||
|
@ -289,12 +291,21 @@ func (s *ClientSession) IsExpired(now time.Time) bool {
|
|||
|
||||
func (s *ClientSession) SetRoom(room *Room) {
|
||||
atomic.StorePointer(&s.room, unsafe.Pointer(room))
|
||||
if room != nil {
|
||||
atomic.StoreInt64(&s.roomJoinTime, time.Now().UnixNano())
|
||||
} else {
|
||||
atomic.StoreInt64(&s.roomJoinTime, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClientSession) GetRoom() *Room {
|
||||
return (*Room)(atomic.LoadPointer(&s.room))
|
||||
}
|
||||
|
||||
func (s *ClientSession) getRoomJoinTime() time.Time {
|
||||
return time.Unix(0, atomic.LoadInt64(&s.roomJoinTime))
|
||||
}
|
||||
|
||||
func (s *ClientSession) releaseMcuObjects() {
|
||||
if len(s.publishers) > 0 {
|
||||
go func(publishers map[string]McuPublisher) {
|
||||
|
@ -815,6 +826,13 @@ func (s *ClientSession) processNatsMessage(msg *NatsMessage) *ServerMessage {
|
|||
// TODO(jojo): Only send all users if current session id has
|
||||
// changed its "inCall" flag to true.
|
||||
m.Changed = nil
|
||||
} else if msg.Message.Event.Target == "room" {
|
||||
// Can happen mostly during tests where an older room NATS message
|
||||
// could be received by a subscriber that joined after it was sent.
|
||||
if msg.SendTime.Before(s.getRoomJoinTime()) {
|
||||
log.Printf("Message %+v was sent before room was joined, ignoring", msg.Message)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -111,10 +112,13 @@ func Test_permissionsEqual(t *testing.T) {
|
|||
equal: false,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
equal := permissionsEqual(test.a, test.b)
|
||||
if equal != test.equal {
|
||||
t.Errorf("Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal])
|
||||
}
|
||||
for idx, test := range tests {
|
||||
test := test
|
||||
t.Run(strconv.Itoa(idx), func(t *testing.T) {
|
||||
equal := permissionsEqual(test.a, test.b)
|
||||
if equal != test.equal {
|
||||
t.Errorf("Expected %+v to be %s to %+v but was %s", test.a, equalStrings[test.equal], test.b, equalStrings[equal])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,15 +42,19 @@ func testGeoLookupReader(t *testing.T, reader *GeoLookup) {
|
|||
}
|
||||
|
||||
for ip, expected := range tests {
|
||||
country, err := reader.LookupCountry(net.ParseIP(ip))
|
||||
if err != nil {
|
||||
t.Errorf("Could not lookup %s: %s", ip, err)
|
||||
continue
|
||||
}
|
||||
ip := ip
|
||||
expected := expected
|
||||
t.Run(ip, func(t *testing.T) {
|
||||
country, err := reader.LookupCountry(net.ParseIP(ip))
|
||||
if err != nil {
|
||||
t.Errorf("Could not lookup %s: %s", ip, err)
|
||||
return
|
||||
}
|
||||
|
||||
if country != expected {
|
||||
t.Errorf("Expected %s for %s, got %s", expected, ip, country)
|
||||
}
|
||||
if country != expected {
|
||||
t.Errorf("Expected %s for %s, got %s", expected, ip, country)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -106,17 +110,21 @@ func TestGeoLookupContinent(t *testing.T) {
|
|||
}
|
||||
|
||||
for country, expected := range tests {
|
||||
continents := LookupContinents(country)
|
||||
if len(continents) != len(expected) {
|
||||
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
|
||||
continue
|
||||
}
|
||||
for idx, c := range expected {
|
||||
if continents[idx] != c {
|
||||
country := country
|
||||
expected := expected
|
||||
t.Run(country, func(t *testing.T) {
|
||||
continents := LookupContinents(country)
|
||||
if len(continents) != len(expected) {
|
||||
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
|
||||
break
|
||||
return
|
||||
}
|
||||
}
|
||||
for idx, c := range expected {
|
||||
if continents[idx] != c {
|
||||
t.Errorf("Continents didn't match for %s: got %s, expected %s", country, continents, expected)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -120,6 +120,7 @@ func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Serve
|
|||
|
||||
WaitForHub(ctx, t, h)
|
||||
(nats).(*LoopbackNatsClient).waitForSubscriptionsEmpty(ctx, t)
|
||||
nats.Close()
|
||||
server.Close()
|
||||
}
|
||||
|
||||
|
|
|
@ -76,11 +76,15 @@ func Test_sortConnectionsForCountry(t *testing.T) {
|
|||
}
|
||||
|
||||
for country, test := range testcases {
|
||||
sorted := sortConnectionsForCountry(test[0], country)
|
||||
for idx, conn := range sorted {
|
||||
if test[1][idx] != conn {
|
||||
t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
|
||||
country := country
|
||||
test := test
|
||||
t.Run(country, func(t *testing.T) {
|
||||
sorted := sortConnectionsForCountry(test[0], country)
|
||||
for idx, conn := range sorted {
|
||||
if test[1][idx] != conn {
|
||||
t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country())
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,6 +38,8 @@ const (
|
|||
)
|
||||
|
||||
type NatsMessage struct {
|
||||
SendTime time.Time `json:"sendtime"`
|
||||
|
||||
Type string `json:"type"`
|
||||
|
||||
Message *ServerMessage `json:"message,omitempty"`
|
||||
|
@ -150,16 +152,18 @@ func (c *natsClient) PublishNats(subject string, message *NatsMessage) error {
|
|||
|
||||
func (c *natsClient) PublishMessage(subject string, message *ServerMessage) error {
|
||||
msg := &NatsMessage{
|
||||
Type: "message",
|
||||
Message: message,
|
||||
SendTime: time.Now(),
|
||||
Type: "message",
|
||||
Message: message,
|
||||
}
|
||||
return c.PublishNats(subject, msg)
|
||||
}
|
||||
|
||||
func (c *natsClient) PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error {
|
||||
msg := &NatsMessage{
|
||||
Type: "room",
|
||||
Room: message,
|
||||
SendTime: time.Now(),
|
||||
Type: "room",
|
||||
Room: message,
|
||||
}
|
||||
return c.PublishNats(subject, msg)
|
||||
}
|
||||
|
|
|
@ -22,7 +22,9 @@
|
|||
package signaling
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -33,90 +35,87 @@ import (
|
|||
type LoopbackNatsClient struct {
|
||||
mu sync.Mutex
|
||||
subscriptions map[string]map[*loopbackNatsSubscription]bool
|
||||
|
||||
stopping bool
|
||||
wakeup sync.Cond
|
||||
incoming list.List
|
||||
}
|
||||
|
||||
func NewLoopbackNatsClient() (NatsClient, error) {
|
||||
return &LoopbackNatsClient{
|
||||
client := &LoopbackNatsClient{
|
||||
subscriptions: make(map[string]map[*loopbackNatsSubscription]bool),
|
||||
}, nil
|
||||
}
|
||||
client.wakeup.L = &client.mu
|
||||
go client.processMessages()
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *LoopbackNatsClient) processMessages() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for {
|
||||
for !c.stopping && c.incoming.Len() == 0 {
|
||||
c.wakeup.Wait()
|
||||
}
|
||||
if c.stopping {
|
||||
break
|
||||
}
|
||||
|
||||
msg := c.incoming.Remove(c.incoming.Front()).(*nats.Msg)
|
||||
c.processMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) {
|
||||
subs, found := c.subscriptions[msg.Subject]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
channels := make([]chan *nats.Msg, 0, len(subs))
|
||||
for sub := range subs {
|
||||
channels = append(channels, sub.ch)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
defer c.mu.Lock()
|
||||
for _, ch := range channels {
|
||||
select {
|
||||
case ch <- msg:
|
||||
default:
|
||||
log.Printf("Slow consumer %s, dropping message", msg.Subject)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LoopbackNatsClient) Close() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for _, subs := range c.subscriptions {
|
||||
for sub := range subs {
|
||||
sub.Unsubscribe() // nolint
|
||||
}
|
||||
}
|
||||
|
||||
c.subscriptions = nil
|
||||
c.stopping = true
|
||||
c.incoming.Init()
|
||||
c.wakeup.Signal()
|
||||
}
|
||||
|
||||
type loopbackNatsSubscription struct {
|
||||
subject string
|
||||
client *LoopbackNatsClient
|
||||
ch chan *nats.Msg
|
||||
incoming []*nats.Msg
|
||||
cond sync.Cond
|
||||
quit bool
|
||||
subject string
|
||||
client *LoopbackNatsClient
|
||||
|
||||
ch chan *nats.Msg
|
||||
}
|
||||
|
||||
func (s *loopbackNatsSubscription) Unsubscribe() error {
|
||||
s.cond.L.Lock()
|
||||
if !s.quit {
|
||||
s.quit = true
|
||||
s.cond.Signal()
|
||||
}
|
||||
s.cond.L.Unlock()
|
||||
|
||||
s.client.unsubscribe(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *loopbackNatsSubscription) queue(msg *nats.Msg) {
|
||||
s.cond.L.Lock()
|
||||
s.incoming = append(s.incoming, msg)
|
||||
if len(s.incoming) == 1 {
|
||||
s.cond.Signal()
|
||||
}
|
||||
s.cond.L.Unlock()
|
||||
}
|
||||
|
||||
func (s *loopbackNatsSubscription) run() {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
for !s.quit {
|
||||
for !s.quit && len(s.incoming) == 0 {
|
||||
s.cond.Wait()
|
||||
}
|
||||
|
||||
for !s.quit && len(s.incoming) > 0 {
|
||||
msg := s.incoming[0]
|
||||
s.incoming = s.incoming[1:]
|
||||
s.cond.L.Unlock()
|
||||
// A "real" NATS server would take some time to process the request,
|
||||
// simulate this by sleeping a tiny bit.
|
||||
time.Sleep(time.Millisecond)
|
||||
s.ch <- msg
|
||||
s.cond.L.Lock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LoopbackNatsClient) Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.subscribe(subject, ch)
|
||||
}
|
||||
|
||||
func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) {
|
||||
if strings.HasSuffix(subject, ".") || strings.Contains(subject, " ") {
|
||||
return nil, nats.ErrBadSubject
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.subscriptions == nil {
|
||||
return nil, nats.ErrConnectionClosed
|
||||
}
|
||||
|
@ -126,7 +125,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS
|
|||
client: c,
|
||||
ch: ch,
|
||||
}
|
||||
s.cond.L = &sync.Mutex{}
|
||||
subs, found := c.subscriptions[subject]
|
||||
if !found {
|
||||
subs = make(map[*loopbackNatsSubscription]bool)
|
||||
|
@ -134,7 +132,6 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS
|
|||
}
|
||||
subs[s] = true
|
||||
|
||||
go s.run()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -161,18 +158,15 @@ func (c *LoopbackNatsClient) Publish(subject string, message interface{}) error
|
|||
return nats.ErrConnectionClosed
|
||||
}
|
||||
|
||||
if subs, found := c.subscriptions[subject]; found {
|
||||
msg := &nats.Msg{
|
||||
Subject: subject,
|
||||
}
|
||||
var err error
|
||||
if msg.Data, err = json.Marshal(message); err != nil {
|
||||
return err
|
||||
}
|
||||
for s := range subs {
|
||||
s.queue(msg)
|
||||
}
|
||||
msg := &nats.Msg{
|
||||
Subject: subject,
|
||||
}
|
||||
var err error
|
||||
if msg.Data, err = json.Marshal(message); err != nil {
|
||||
return err
|
||||
}
|
||||
c.incoming.PushBack(msg)
|
||||
c.wakeup.Signal()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -182,16 +176,18 @@ func (c *LoopbackNatsClient) PublishNats(subject string, message *NatsMessage) e
|
|||
|
||||
func (c *LoopbackNatsClient) PublishMessage(subject string, message *ServerMessage) error {
|
||||
msg := &NatsMessage{
|
||||
Type: "message",
|
||||
Message: message,
|
||||
SendTime: time.Now(),
|
||||
Type: "message",
|
||||
Message: message,
|
||||
}
|
||||
return c.PublishNats(subject, msg)
|
||||
}
|
||||
|
||||
func (c *LoopbackNatsClient) PublishBackendServerRoomRequest(subject string, message *BackendServerRoomRequest) error {
|
||||
msg := &NatsMessage{
|
||||
Type: "room",
|
||||
Room: message,
|
||||
SendTime: time.Now(),
|
||||
Type: "room",
|
||||
Room: message,
|
||||
}
|
||||
return c.PublishNats(subject, msg)
|
||||
}
|
||||
|
|
|
@ -48,17 +48,20 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t
|
|||
}
|
||||
}
|
||||
|
||||
func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient {
|
||||
func CreateLoopbackNatsClientForTest(t *testing.T) (NatsClient, func()) {
|
||||
result, err := NewLoopbackNatsClient()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return result
|
||||
return result, func() {
|
||||
result.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoopbackNatsClient_Subscribe(t *testing.T) {
|
||||
ensureNoGoroutinesLeak(t, func() {
|
||||
client := CreateLoopbackNatsClientForTest(t)
|
||||
client, shutdown := CreateLoopbackNatsClientForTest(t)
|
||||
defer shutdown()
|
||||
|
||||
testNatsClient_Subscribe(t, client)
|
||||
})
|
||||
|
@ -66,7 +69,8 @@ func TestLoopbackNatsClient_Subscribe(t *testing.T) {
|
|||
|
||||
func TestLoopbackClient_PublishAfterClose(t *testing.T) {
|
||||
ensureNoGoroutinesLeak(t, func() {
|
||||
client := CreateLoopbackNatsClientForTest(t)
|
||||
client, shutdown := CreateLoopbackNatsClientForTest(t)
|
||||
defer shutdown()
|
||||
|
||||
testNatsClient_PublishAfterClose(t, client)
|
||||
})
|
||||
|
@ -74,7 +78,8 @@ func TestLoopbackClient_PublishAfterClose(t *testing.T) {
|
|||
|
||||
func TestLoopbackClient_SubscribeAfterClose(t *testing.T) {
|
||||
ensureNoGoroutinesLeak(t, func() {
|
||||
client := CreateLoopbackNatsClientForTest(t)
|
||||
client, shutdown := CreateLoopbackNatsClientForTest(t)
|
||||
defer shutdown()
|
||||
|
||||
testNatsClient_SubscribeAfterClose(t, client)
|
||||
})
|
||||
|
@ -82,7 +87,8 @@ func TestLoopbackClient_SubscribeAfterClose(t *testing.T) {
|
|||
|
||||
func TestLoopbackClient_BadSubjects(t *testing.T) {
|
||||
ensureNoGoroutinesLeak(t, func() {
|
||||
client := CreateLoopbackNatsClientForTest(t)
|
||||
client, shutdown := CreateLoopbackNatsClientForTest(t)
|
||||
defer shutdown()
|
||||
|
||||
testNatsClient_BadSubjects(t, client)
|
||||
})
|
||||
|
|
|
@ -90,7 +90,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
|
|||
}
|
||||
|
||||
// Allow NATS goroutines to process messages.
|
||||
time.Sleep(time.Millisecond)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
<-ch
|
||||
|
||||
|
|
|
@ -227,7 +227,13 @@ func (c *TestClient) CloseWithBye() {
|
|||
}
|
||||
|
||||
func (c *TestClient) Close() {
|
||||
c.conn.WriteMessage(websocket.CloseMessage, []byte{}) // nolint
|
||||
if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err == websocket.ErrCloseSent {
|
||||
// Already closed
|
||||
return
|
||||
}
|
||||
|
||||
// Wait a bit for close message to be processed.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
c.conn.Close()
|
||||
|
||||
// Drain any entries in the channels to terminate the read goroutine.
|
||||
|
|
Loading…
Reference in New Issue