Wait for NATS client to be closed.

This commit is contained in:
Joachim Bauch 2025-11-20 10:14:48 +01:00
commit bfcabaa2fc
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
10 changed files with 73 additions and 18 deletions

View file

@ -43,7 +43,7 @@ type AsyncSessionEventListener interface {
}
type AsyncEvents interface {
Close()
Close(ctx context.Context) error
RegisterBackendRoomListener(roomId string, backend *Backend, listener AsyncBackendRoomEventListener) error
UnregisterBackendRoomListener(roomId string, backend *Backend, listener AsyncBackendRoomEventListener)

View file

@ -22,6 +22,7 @@
package signaling
import (
"context"
"fmt"
"sync"
"time"
@ -281,7 +282,7 @@ func (e *asyncEventsNats) GetServerInfoNats() *BackendServerInfoNats {
return nats
}
func (e *asyncEventsNats) Close() {
func (e *asyncEventsNats) Close(ctx context.Context) error {
e.mu.Lock()
defer e.mu.Unlock()
var wg sync.WaitGroup
@ -320,7 +321,7 @@ func (e *asyncEventsNats) Close() {
e.userSubscriptions = make(map[string]*asyncUserSubscriberNats)
e.sessionSubscriptions = make(map[string]*asyncSessionSubscriberNats)
wg.Wait()
e.client.Close()
return e.client.Close(ctx)
}
func (e *asyncEventsNats) RegisterBackendRoomListener(roomId string, backend *Backend, listener AsyncBackendRoomEventListener) error {

View file

@ -25,7 +25,9 @@ import (
"context"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -44,7 +46,9 @@ func getAsyncEventsForTest(t *testing.T) AsyncEvents {
events = getLoopbackAsyncEventsForTest(t)
}
t.Cleanup(func() {
events.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
assert.NoError(t, events.Close(ctx))
})
return events
}

View file

@ -125,6 +125,7 @@ func CreateBackendServerWithClusteringForTest(t *testing.T) (*BackendServer, *Ba
func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *goconf.ConfigFile, config2 *goconf.ConfigFile) (*BackendServer, *BackendServer, *Hub, *Hub, *httptest.Server, *httptest.Server) {
require := require.New(t)
assert := assert.New(t)
r1 := mux.NewRouter()
registerBackendHandler(t, r1)
@ -166,7 +167,9 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g
events1, err := NewAsyncEvents(ctx, nats.ClientURL())
require.NoError(err)
t.Cleanup(func() {
events1.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
assert.NoError(events1.Close(ctx))
})
client1, _ := NewGrpcClientsForTest(t, addr2)
hub1, err := NewHub(ctx, config1, events1, grpcServer1, client1, nil, r1, "no-version")
@ -189,7 +192,9 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g
events2, err := NewAsyncEvents(ctx, nats.ClientURL())
require.NoError(err)
t.Cleanup(func() {
events2.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
assert.NoError(events2.Close(ctx))
})
client2, _ := NewGrpcClientsForTest(t, addr1)
hub2, err := NewHub(ctx, config2, events2, grpcServer2, client2, nil, r2, "no-version")

View file

@ -204,6 +204,7 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http
logger := NewLoggerForTest(t)
ctx := NewLoggerContext(t.Context(), logger)
require := require.New(t)
assert := assert.New(t)
r1 := mux.NewRouter()
registerBackendHandler(t, r1)
@ -238,7 +239,9 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http
events1, err := NewAsyncEvents(ctx, nats1.ClientURL())
require.NoError(err)
t.Cleanup(func() {
events1.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
assert.NoError(events1.Close(ctx))
})
config1, err := getConfigFunc(server1)
require.NoError(err)
@ -250,7 +253,9 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http
events2, err := NewAsyncEvents(ctx, nats2.ClientURL())
require.NoError(err)
t.Cleanup(func() {
events2.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
assert.NoError(events2.Close(ctx))
})
config2, err := getConfigFunc(server2)
require.NoError(err)

View file

@ -47,7 +47,7 @@ type NatsSubscription interface {
}
type NatsClient interface {
Close()
Close(ctx context.Context) error
Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error)
Publish(subject string, message any) error
@ -65,6 +65,7 @@ func GetEncodedSubject(prefix string, suffix string) string {
type natsClient struct {
logger Logger
conn *nats.Conn
closed chan struct{}
}
func NewNatsClient(ctx context.Context, url string, options ...nats.Option) (NatsClient, error) {
@ -85,6 +86,7 @@ func NewNatsClient(ctx context.Context, url string, options ...nats.Option) (Nat
client := &natsClient{
logger: logger,
closed: make(chan struct{}),
}
options = append([]nats.Option{
@ -112,12 +114,23 @@ func NewNatsClient(ctx context.Context, url string, options ...nats.Option) (Nat
return client, nil
}
func (c *natsClient) Close() {
func (c *natsClient) Close(ctx context.Context) error {
c.conn.Close()
select {
case <-c.closed:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (c *natsClient) onClosed(conn *nats.Conn) {
c.logger.Println("NATS client closed", conn.LastError())
if err := conn.LastError(); err != nil {
c.logger.Printf("NATS client closed, last error %s", conn.LastError())
} else {
c.logger.Println("NATS client closed")
}
close(c.closed)
}
func (c *natsClient) onDisconnected(conn *nats.Conn) {

View file

@ -23,6 +23,7 @@ package signaling
import (
"container/list"
"context"
"encoding/json"
"strings"
"sync"
@ -33,7 +34,9 @@ import (
type LoopbackNatsClient struct {
logger Logger
mu sync.Mutex
mu sync.Mutex
closed chan struct{}
// +checklocks:mu
subscriptions map[string]map[*loopbackNatsSubscription]bool
@ -46,6 +49,7 @@ type LoopbackNatsClient struct {
func NewLoopbackNatsClient(logger Logger) (NatsClient, error) {
client := &LoopbackNatsClient{
logger: logger,
closed: make(chan struct{}),
subscriptions: make(map[string]map[*loopbackNatsSubscription]bool),
}
@ -55,6 +59,8 @@ func NewLoopbackNatsClient(logger Logger) (NatsClient, error) {
}
func (c *LoopbackNatsClient) processMessages() {
defer close(c.closed)
c.mu.Lock()
defer c.mu.Unlock()
for {
@ -93,7 +99,7 @@ func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) {
}
}
func (c *LoopbackNatsClient) Close() {
func (c *LoopbackNatsClient) doClose() {
c.mu.Lock()
defer c.mu.Unlock()
@ -102,6 +108,16 @@ func (c *LoopbackNatsClient) Close() {
c.wakeup.Signal()
}
func (c *LoopbackNatsClient) Close(ctx context.Context) error {
c.doClose()
select {
case <-c.closed:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
type loopbackNatsSubscription struct {
subject string
client *LoopbackNatsClient

View file

@ -56,7 +56,9 @@ func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient {
result, err := NewLoopbackNatsClient(logger)
require.NoError(t, err)
t.Cleanup(func() {
result.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
assert.NoError(t, result.Close(ctx))
})
return result
}

View file

@ -22,6 +22,7 @@
package signaling
import (
"context"
"sync/atomic"
"testing"
"time"
@ -60,7 +61,9 @@ func CreateLocalNatsClientForTest(t *testing.T, options ...nats.Option) (*server
result, err := NewNatsClient(ctx, server.ClientURL(), options...)
require.NoError(t, err)
t.Cleanup(func() {
result.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
assert.NoError(t, result.Close(ctx))
})
return server, port, result
}
@ -116,7 +119,7 @@ func TestNatsClient_Subscribe(t *testing.T) {
}
func testNatsClient_PublishAfterClose(t *testing.T, client NatsClient) {
client.Close()
assert.NoError(t, client.Close(t.Context()))
assert.ErrorIs(t, client.Publish("foo", "bar"), nats.ErrConnectionClosed)
}
@ -130,7 +133,7 @@ func TestNatsClient_PublishAfterClose(t *testing.T) {
}
func testNatsClient_SubscribeAfterClose(t *testing.T, client NatsClient) {
client.Close()
assert.NoError(t, client.Close(t.Context()))
ch := make(chan *nats.Msg)
_, err := client.Subscribe("foo", ch)

View file

@ -184,7 +184,13 @@ func main() {
if err != nil {
logger.Fatal("Could not create async events client: ", err)
}
defer events.Close()
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := events.Close(ctx); err != nil {
logger.Printf("Error closing events handler: %s", err)
}
}()
dnsMonitor, err := signaling.NewDnsMonitor(logger, dnsMonitorInterval)
if err != nil {