diff --git a/go.mod b/go.mod index 7e4102d..4bc0ec9 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/gorilla/securecookie v1.1.1 github.com/gorilla/websocket v1.4.2 github.com/mailru/easyjson v0.7.7 - github.com/nats-io/nats-server/v2 v2.2.1 // indirect + github.com/nats-io/nats-server/v2 v2.2.1 github.com/nats-io/nats.go v1.10.1-0.20210330225420-a0b1f60162f8 github.com/notedit/janus-go v0.0.0-20200517101215-10eb8b95d1a0 github.com/oschwald/maxminddb-golang v1.8.0 diff --git a/natsclient.go b/natsclient.go index 815ba23..162e2e0 100644 --- a/natsclient.go +++ b/natsclient.go @@ -52,6 +52,8 @@ type NatsSubscription interface { } type NatsClient interface { + Close() + Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) Request(subject string, data []byte, timeout time.Duration) (*nats.Msg, error) @@ -120,6 +122,10 @@ func NewNatsClient(url string) (NatsClient, error) { return client, nil } +func (c *natsClient) Close() { + c.conn.Close() +} + func (c *natsClient) onClosed(conn *nats.Conn) { log.Println("NATS client closed", conn.LastError()) } diff --git a/natsclient_loopback.go b/natsclient_loopback.go index 78f660d..5e4e08b 100644 --- a/natsclient_loopback.go +++ b/natsclient_loopback.go @@ -45,6 +45,19 @@ func NewLoopbackNatsClient() (NatsClient, error) { }, nil } +func (c *LoopbackNatsClient) Close() { + c.mu.Lock() + defer c.mu.Unlock() + + for _, subs := range c.subscriptions { + for sub := range subs { + sub.Unsubscribe() // nolint + } + } + + c.subscriptions = nil +} + type loopbackNatsSubscription struct { subject string client *LoopbackNatsClient @@ -105,6 +118,10 @@ func (c *LoopbackNatsClient) subscribe(subject string, ch chan *nats.Msg) (NatsS return nil, nats.ErrBadSubject } + if c.subscriptions == nil { + return nil, nats.ErrConnectionClosed + } + s := &loopbackNatsSubscription{ subject: subject, client: c, @@ -141,21 +158,15 @@ func (c *LoopbackNatsClient) Request(subject string, data []byte, timeout time.D c.mu.Lock() defer c.mu.Unlock() + if c.subscriptions == nil { + return nil, nats.ErrConnectionClosed + } + var response *nats.Msg var err error subs, found := c.subscriptions[subject] if !found { - c.mu.Unlock() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - err = nats.ErrTimeout - } else { - err = ctx.Err() - } - c.mu.Lock() - return nil, err + return nil, nats.ErrNoResponders } replyId := c.replyId @@ -212,6 +223,10 @@ func (c *LoopbackNatsClient) Publish(subject string, message interface{}) error c.mu.Lock() defer c.mu.Unlock() + if c.subscriptions == nil { + return nats.ErrConnectionClosed + } + if subs, found := c.subscriptions[subject]; found { msg := &nats.Msg{ Subject: subject, diff --git a/natsclient_loopback_test.go b/natsclient_loopback_test.go index b98fcc1..7fec2d1 100644 --- a/natsclient_loopback_test.go +++ b/natsclient_loopback_test.go @@ -22,14 +22,9 @@ package signaling import ( - "bytes" "context" - "runtime" - "sync/atomic" "testing" "time" - - "github.com/nats-io/nats.go" ) func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *testing.T) { @@ -62,167 +57,33 @@ func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient { } func TestLoopbackNatsClient_Subscribe(t *testing.T) { - // Give time for things to settle before capturing the number of - // go routines - time.Sleep(500 * time.Millisecond) + ensureNoGoroutinesLeak(t, func() { + client := CreateLoopbackNatsClientForTest(t) - base := runtime.NumGoroutine() - - client := CreateLoopbackNatsClientForTest(t) - dest := make(chan *nats.Msg) - sub, err := client.Subscribe("foo", dest) - if err != nil { - t.Fatal(err) - } - ch := make(chan bool) - - received := int32(0) - max := int32(20) - quit := make(chan bool) - go func() { - for { - select { - case <-dest: - total := atomic.AddInt32(&received, 1) - if total == max { - err := sub.Unsubscribe() - if err != nil { - t.Errorf("Unsubscribe failed with err: %s", err) - return - } - ch <- true - } - case <-quit: - return - } - } - }() - for i := int32(0); i < max; i++ { - if err := client.Publish("foo", []byte("hello")); err != nil { - t.Error(err) - } - } - <-ch - - r := atomic.LoadInt32(&received) - if r != max { - t.Fatalf("Received wrong # of messages: %d vs %d", r, max) - } - quit <- true - - // Give time for things to settle before capturing the number of - // go routines - time.Sleep(500 * time.Millisecond) - - delta := (runtime.NumGoroutine() - base) - if delta > 0 { - t.Fatalf("%d Go routines still exist post Close()", delta) - } + testNatsClient_Subscribe(t, client) + }) } func TestLoopbackNatsClient_Request(t *testing.T) { - // Give time for things to settle before capturing the number of - // go routines - time.Sleep(500 * time.Millisecond) + ensureNoGoroutinesLeak(t, func() { + client := CreateLoopbackNatsClientForTest(t) - base := runtime.NumGoroutine() - - client := CreateLoopbackNatsClientForTest(t) - dest := make(chan *nats.Msg) - sub, err := client.Subscribe("foo", dest) - if err != nil { - t.Fatal(err) - } - - go func() { - msg := <-dest - if err := client.Publish(msg.Reply, []byte("world")); err != nil { - t.Error(err) - return - } - if err := sub.Unsubscribe(); err != nil { - t.Error("Unsubscribe failed with err:", err) - return - } - }() - reply, err := client.Request("foo", []byte("hello"), 1*time.Second) - if err != nil { - t.Fatal(err) - } - - var response []byte - if err := client.Decode(reply, &response); err != nil { - t.Fatal(err) - } - if !bytes.Equal(response, []byte("world")) { - t.Fatalf("expected 'world', got '%s'", string(reply.Data)) - } - - // Give time for things to settle before capturing the number of - // go routines - time.Sleep(500 * time.Millisecond) - - delta := (runtime.NumGoroutine() - base) - if delta > 0 { - t.Fatalf("%d Go routines still exist post Close()", delta) - } + testNatsClient_Request(t, client) + }) } func TestLoopbackNatsClient_RequestTimeout(t *testing.T) { - // Give time for things to settle before capturing the number of - // go routines - time.Sleep(500 * time.Millisecond) + ensureNoGoroutinesLeak(t, func() { + client := CreateLoopbackNatsClientForTest(t) - base := runtime.NumGoroutine() - - client := CreateLoopbackNatsClientForTest(t) - dest := make(chan *nats.Msg) - sub, err := client.Subscribe("foo", dest) - if err != nil { - t.Fatal(err) - } - - go func() { - msg := <-dest - time.Sleep(200 * time.Millisecond) - if err := client.Publish(msg.Reply, []byte("world")); err != nil { - t.Error(err) - return - } - if err := sub.Unsubscribe(); err != nil { - t.Error("Unsubscribe failed with err:", err) - return - } - }() - reply, err := client.Request("foo", []byte("hello"), 100*time.Millisecond) - if err == nil { - t.Fatalf("Request should have timed out, reeived %+v", reply) - } else if err != nats.ErrTimeout { - t.Fatalf("Request should have timed out, received error %s", err) - } - - // Give time for things to settle before capturing the number of - // go routines - time.Sleep(500 * time.Millisecond) - - delta := (runtime.NumGoroutine() - base) - if delta > 0 { - t.Fatalf("%d Go routines still exist post Close()", delta) - } + testNatsClient_RequestTimeout(t, client) + }) } -func TestLoopbackNatsClient_RequestTimeoutNoReply(t *testing.T) { - client := CreateLoopbackNatsClientForTest(t) - timeout := 100 * time.Millisecond - start := time.Now() - reply, err := client.Request("foo", []byte("hello"), timeout) - end := time.Now() - if err == nil { - t.Fatalf("Request should have timed out, reeived %+v", reply) - } else if err != nats.ErrTimeout { - t.Fatalf("Request should have timed out, received error %s", err) - } - if end.Sub(start) < timeout { - t.Errorf("Expected a delay of %s but had %s", timeout, end.Sub(start)) - } +func TestLoopbackNatsClient_RequestNoReply(t *testing.T) { + ensureNoGoroutinesLeak(t, func() { + client := CreateLoopbackNatsClientForTest(t) + + testNatsClient_RequestNoReply(t, client) + }) } diff --git a/natsclient_test.go b/natsclient_test.go new file mode 100644 index 0000000..79cfbd8 --- /dev/null +++ b/natsclient_test.go @@ -0,0 +1,214 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2021 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/nats-io/nats.go" + + natsserver "github.com/nats-io/nats-server/v2/test" +) + +func startLocalNatsServer() (string, func()) { + opts := natsserver.DefaultTestOptions + opts.Port = -1 + opts.Cluster.Name = "testing" + srv := natsserver.RunServer(&opts) + shutdown := func() { + srv.Shutdown() + srv.WaitForShutdown() + } + return srv.ClientURL(), shutdown +} + +func CreateLocalNatsClientForTest(t *testing.T) (NatsClient, func()) { + url, shutdown := startLocalNatsServer() + result, err := NewNatsClient(url) + if err != nil { + t.Fatal(err) + } + return result, func() { + result.Close() + shutdown() + } +} + +func testNatsClient_Subscribe(t *testing.T, client NatsClient) { + dest := make(chan *nats.Msg) + sub, err := client.Subscribe("foo", dest) + if err != nil { + t.Fatal(err) + } + ch := make(chan bool) + + received := int32(0) + max := int32(20) + quit := make(chan bool) + go func() { + for { + select { + case <-dest: + total := atomic.AddInt32(&received, 1) + if total == max { + err := sub.Unsubscribe() + if err != nil { + t.Errorf("Unsubscribe failed with err: %s", err) + return + } + ch <- true + } + case <-quit: + return + } + } + }() + for i := int32(0); i < max; i++ { + if err := client.Publish("foo", []byte("hello")); err != nil { + t.Error(err) + } + + // Allow NATS goroutines to process messages. + time.Sleep(time.Millisecond) + } + <-ch + + r := atomic.LoadInt32(&received) + if r != max { + t.Fatalf("Received wrong # of messages: %d vs %d", r, max) + } + quit <- true +} + +func TestNatsClient_Subscribe(t *testing.T) { + ensureNoGoroutinesLeak(t, func() { + client, shutdown := CreateLocalNatsClientForTest(t) + defer shutdown() + + testNatsClient_Subscribe(t, client) + }) +} + +func testNatsClient_Request(t *testing.T, client NatsClient) { + dest := make(chan *nats.Msg) + sub, err := client.Subscribe("foo", dest) + if err != nil { + t.Fatal(err) + } + + go func() { + msg := <-dest + if err := client.Publish(msg.Reply, "world"); err != nil { + t.Error(err) + return + } + if err := sub.Unsubscribe(); err != nil { + t.Error("Unsubscribe failed with err:", err) + return + } + }() + reply, err := client.Request("foo", []byte("hello"), 30*time.Second) + if err != nil { + t.Fatal(err) + } + + var response string + if err := client.Decode(reply, &response); err != nil { + t.Fatal(err) + } + if response != "world" { + t.Fatalf("expected 'world', got '%s'", string(reply.Data)) + } +} + +func TestNatsClient_Request(t *testing.T) { + ensureNoGoroutinesLeak(t, func() { + client, shutdown := CreateLocalNatsClientForTest(t) + defer shutdown() + + testNatsClient_Request(t, client) + }) +} + +func testNatsClient_RequestTimeout(t *testing.T, client NatsClient) { + dest := make(chan *nats.Msg) + sub, err := client.Subscribe("foo", dest) + if err != nil { + t.Fatal(err) + } + + go func() { + msg := <-dest + time.Sleep(200 * time.Millisecond) + if err := client.Publish(msg.Reply, []byte("world")); err != nil { + if err != nats.ErrConnectionClosed { + t.Error(err) + } + return + } + if err := sub.Unsubscribe(); err != nil { + t.Error("Unsubscribe failed with err:", err) + return + } + }() + reply, err := client.Request("foo", []byte("hello"), 100*time.Millisecond) + if err == nil { + t.Fatalf("Request should have timed out, reeived %+v", reply) + } else if err != nats.ErrTimeout { + t.Fatalf("Request should have timed out, received error %s", err) + } +} + +func TestNatsClient_RequestTimeout(t *testing.T) { + ensureNoGoroutinesLeak(t, func() { + client, shutdown := CreateLocalNatsClientForTest(t) + defer shutdown() + + testNatsClient_RequestTimeout(t, client) + }) +} + +func testNatsClient_RequestNoReply(t *testing.T, client NatsClient) { + timeout := 100 * time.Millisecond + start := time.Now() + reply, err := client.Request("foo", []byte("hello"), timeout) + end := time.Now() + if err == nil { + t.Fatalf("Request should have failed without responsers, reeived %+v", reply) + } else if err != nats.ErrNoResponders { + t.Fatalf("Request should have failed without responsers, received error %s", err) + } + if end.Sub(start) >= timeout { + t.Errorf("Should have failed immediately but took %s", end.Sub(start)) + } +} + +func TestNatsClient_RequestNoReply(t *testing.T) { + ensureNoGoroutinesLeak(t, func() { + client, shutdown := CreateLocalNatsClientForTest(t) + defer shutdown() + + testNatsClient_RequestNoReply(t, client) + }) +} diff --git a/testutils_test.go b/testutils_test.go new file mode 100644 index 0000000..f099948 --- /dev/null +++ b/testutils_test.go @@ -0,0 +1,57 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2021 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "os" + "runtime/pprof" + "testing" + "time" +) + +func ensureNoGoroutinesLeak(t *testing.T, f func()) { + // Give time for things to settle before capturing the number of + // go routines + time.Sleep(500 * time.Millisecond) + before := pprof.Lookup("goroutine") + + f() + + var after *pprof.Profile + // Give time for things to settle before capturing the number of + // go routines + timeout := time.Now().Add(time.Second) + for time.Now().Before(timeout) { + after = pprof.Lookup("goroutine") + if after.Count() == before.Count() { + break + } + } + + if after.Count() != before.Count() { + os.Stderr.WriteString("Before:\n") + before.WriteTo(os.Stderr, 1) // nolint + os.Stderr.WriteString("After:\n") + after.WriteTo(os.Stderr, 1) // nolint + t.Fatalf("Number of Go routines has changed in %s", t.Name()) + } +}