diff --git a/async_events.go b/async_events.go index 6598cd3..bbb8aab 100644 --- a/async_events.go +++ b/async_events.go @@ -43,7 +43,7 @@ type AsyncSessionEventListener interface { } type AsyncEvents interface { - Close() + Close(ctx context.Context) error RegisterBackendRoomListener(roomId string, backend *Backend, listener AsyncBackendRoomEventListener) error UnregisterBackendRoomListener(roomId string, backend *Backend, listener AsyncBackendRoomEventListener) diff --git a/async_events_nats.go b/async_events_nats.go index ab9fe13..d0ce29b 100644 --- a/async_events_nats.go +++ b/async_events_nats.go @@ -22,6 +22,7 @@ package signaling import ( + "context" "fmt" "sync" "time" @@ -281,7 +282,7 @@ func (e *asyncEventsNats) GetServerInfoNats() *BackendServerInfoNats { return nats } -func (e *asyncEventsNats) Close() { +func (e *asyncEventsNats) Close(ctx context.Context) error { e.mu.Lock() defer e.mu.Unlock() var wg sync.WaitGroup @@ -320,7 +321,7 @@ func (e *asyncEventsNats) Close() { e.userSubscriptions = make(map[string]*asyncUserSubscriberNats) e.sessionSubscriptions = make(map[string]*asyncSessionSubscriberNats) wg.Wait() - e.client.Close() + return e.client.Close(ctx) } func (e *asyncEventsNats) RegisterBackendRoomListener(roomId string, backend *Backend, listener AsyncBackendRoomEventListener) error { diff --git a/async_events_test.go b/async_events_test.go index 02d6145..ac14dfb 100644 --- a/async_events_test.go +++ b/async_events_test.go @@ -25,7 +25,9 @@ import ( "context" "strings" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -44,7 +46,9 @@ func getAsyncEventsForTest(t *testing.T) AsyncEvents { events = getLoopbackAsyncEventsForTest(t) } t.Cleanup(func() { - events.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + assert.NoError(t, events.Close(ctx)) }) return events } diff --git a/backend_server_test.go b/backend_server_test.go index 460bb7d..7a89a82 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -125,6 +125,7 @@ func CreateBackendServerWithClusteringForTest(t *testing.T) (*BackendServer, *Ba func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *goconf.ConfigFile, config2 *goconf.ConfigFile) (*BackendServer, *BackendServer, *Hub, *Hub, *httptest.Server, *httptest.Server) { require := require.New(t) + assert := assert.New(t) r1 := mux.NewRouter() registerBackendHandler(t, r1) @@ -166,7 +167,9 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g events1, err := NewAsyncEvents(ctx, nats.ClientURL()) require.NoError(err) t.Cleanup(func() { - events1.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + assert.NoError(events1.Close(ctx)) }) client1, _ := NewGrpcClientsForTest(t, addr2) hub1, err := NewHub(ctx, config1, events1, grpcServer1, client1, nil, r1, "no-version") @@ -189,7 +192,9 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g events2, err := NewAsyncEvents(ctx, nats.ClientURL()) require.NoError(err) t.Cleanup(func() { - events2.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + assert.NoError(events2.Close(ctx)) }) client2, _ := NewGrpcClientsForTest(t, addr1) hub2, err := NewHub(ctx, config2, events2, grpcServer2, client2, nil, r2, "no-version") diff --git a/hub_test.go b/hub_test.go index 3cda700..fab8ad4 100644 --- a/hub_test.go +++ b/hub_test.go @@ -204,6 +204,7 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http logger := NewLoggerForTest(t) ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) + assert := assert.New(t) r1 := mux.NewRouter() registerBackendHandler(t, r1) @@ -238,7 +239,9 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http events1, err := NewAsyncEvents(ctx, nats1.ClientURL()) require.NoError(err) t.Cleanup(func() { - events1.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + assert.NoError(events1.Close(ctx)) }) config1, err := getConfigFunc(server1) require.NoError(err) @@ -250,7 +253,9 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http events2, err := NewAsyncEvents(ctx, nats2.ClientURL()) require.NoError(err) t.Cleanup(func() { - events2.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + assert.NoError(events2.Close(ctx)) }) config2, err := getConfigFunc(server2) require.NoError(err) diff --git a/natsclient.go b/natsclient.go index 1474155..10292e7 100644 --- a/natsclient.go +++ b/natsclient.go @@ -47,7 +47,7 @@ type NatsSubscription interface { } type NatsClient interface { - Close() + Close(ctx context.Context) error Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) Publish(subject string, message any) error @@ -65,6 +65,7 @@ func GetEncodedSubject(prefix string, suffix string) string { type natsClient struct { logger Logger conn *nats.Conn + closed chan struct{} } func NewNatsClient(ctx context.Context, url string, options ...nats.Option) (NatsClient, error) { @@ -85,6 +86,7 @@ func NewNatsClient(ctx context.Context, url string, options ...nats.Option) (Nat client := &natsClient{ logger: logger, + closed: make(chan struct{}), } options = append([]nats.Option{ @@ -112,12 +114,23 @@ func NewNatsClient(ctx context.Context, url string, options ...nats.Option) (Nat return client, nil } -func (c *natsClient) Close() { +func (c *natsClient) Close(ctx context.Context) error { c.conn.Close() + select { + case <-c.closed: + return nil + case <-ctx.Done(): + return ctx.Err() + } } func (c *natsClient) onClosed(conn *nats.Conn) { - c.logger.Println("NATS client closed", conn.LastError()) + if err := conn.LastError(); err != nil { + c.logger.Printf("NATS client closed, last error %s", conn.LastError()) + } else { + c.logger.Println("NATS client closed") + } + close(c.closed) } func (c *natsClient) onDisconnected(conn *nats.Conn) { diff --git a/natsclient_loopback.go b/natsclient_loopback.go index 1421d9a..a478beb 100644 --- a/natsclient_loopback.go +++ b/natsclient_loopback.go @@ -23,6 +23,7 @@ package signaling import ( "container/list" + "context" "encoding/json" "strings" "sync" @@ -33,7 +34,9 @@ import ( type LoopbackNatsClient struct { logger Logger - mu sync.Mutex + mu sync.Mutex + closed chan struct{} + // +checklocks:mu subscriptions map[string]map[*loopbackNatsSubscription]bool @@ -46,6 +49,7 @@ type LoopbackNatsClient struct { func NewLoopbackNatsClient(logger Logger) (NatsClient, error) { client := &LoopbackNatsClient{ logger: logger, + closed: make(chan struct{}), subscriptions: make(map[string]map[*loopbackNatsSubscription]bool), } @@ -55,6 +59,8 @@ func NewLoopbackNatsClient(logger Logger) (NatsClient, error) { } func (c *LoopbackNatsClient) processMessages() { + defer close(c.closed) + c.mu.Lock() defer c.mu.Unlock() for { @@ -93,7 +99,7 @@ func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) { } } -func (c *LoopbackNatsClient) Close() { +func (c *LoopbackNatsClient) doClose() { c.mu.Lock() defer c.mu.Unlock() @@ -102,6 +108,16 @@ func (c *LoopbackNatsClient) Close() { c.wakeup.Signal() } +func (c *LoopbackNatsClient) Close(ctx context.Context) error { + c.doClose() + select { + case <-c.closed: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + type loopbackNatsSubscription struct { subject string client *LoopbackNatsClient diff --git a/natsclient_loopback_test.go b/natsclient_loopback_test.go index 6cf6ed7..9a299ca 100644 --- a/natsclient_loopback_test.go +++ b/natsclient_loopback_test.go @@ -56,7 +56,9 @@ func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient { result, err := NewLoopbackNatsClient(logger) require.NoError(t, err) t.Cleanup(func() { - result.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + assert.NoError(t, result.Close(ctx)) }) return result } diff --git a/natsclient_test.go b/natsclient_test.go index cc3b22b..766e123 100644 --- a/natsclient_test.go +++ b/natsclient_test.go @@ -22,6 +22,7 @@ package signaling import ( + "context" "sync/atomic" "testing" "time" @@ -60,7 +61,9 @@ func CreateLocalNatsClientForTest(t *testing.T, options ...nats.Option) (*server result, err := NewNatsClient(ctx, server.ClientURL(), options...) require.NoError(t, err) t.Cleanup(func() { - result.Close() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + assert.NoError(t, result.Close(ctx)) }) return server, port, result } @@ -116,7 +119,7 @@ func TestNatsClient_Subscribe(t *testing.T) { } func testNatsClient_PublishAfterClose(t *testing.T, client NatsClient) { - client.Close() + assert.NoError(t, client.Close(t.Context())) assert.ErrorIs(t, client.Publish("foo", "bar"), nats.ErrConnectionClosed) } @@ -130,7 +133,7 @@ func TestNatsClient_PublishAfterClose(t *testing.T) { } func testNatsClient_SubscribeAfterClose(t *testing.T, client NatsClient) { - client.Close() + assert.NoError(t, client.Close(t.Context())) ch := make(chan *nats.Msg) _, err := client.Subscribe("foo", ch) diff --git a/server/main.go b/server/main.go index 444d2ca..fd34249 100644 --- a/server/main.go +++ b/server/main.go @@ -184,7 +184,13 @@ func main() { if err != nil { logger.Fatal("Could not create async events client: ", err) } - defer events.Close() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := events.Close(ctx); err != nil { + logger.Printf("Error closing events handler: %s", err) + } + }() dnsMonitor, err := signaling.NewDnsMonitor(logger, dnsMonitorInterval) if err != nil {