From 1232bfb3b3030b9410084a890c963b47b994c40c Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 5 Mar 2025 16:35:50 +0100 Subject: [PATCH] nats: Reconnect client indefinitely. --- async_events_test.go | 4 +-- backend_server_test.go | 6 ++--- hub_test.go | 11 ++++---- natsclient.go | 9 ++++--- natsclient_test.go | 60 ++++++++++++++++++++++++++++++++++-------- 5 files changed, 66 insertions(+), 24 deletions(-) diff --git a/async_events_test.go b/async_events_test.go index b72a30a..3ded5a0 100644 --- a/async_events_test.go +++ b/async_events_test.go @@ -50,8 +50,8 @@ func getAsyncEventsForTest(t *testing.T) AsyncEvents { } func getRealAsyncEventsForTest(t *testing.T) AsyncEvents { - url := startLocalNatsServer(t) - events, err := NewAsyncEvents(url) + server, _ := startLocalNatsServer(t) + events, err := NewAsyncEvents(server.ClientURL()) if err != nil { require.NoError(t, err) } diff --git a/backend_server_test.go b/backend_server_test.go index 858b7f3..8a84805 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -137,7 +137,7 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g server2.Close() }) - nats := startLocalNatsServer(t) + nats, _ := startLocalNatsServer(t) grpcServer1, addr1 := NewGrpcServerForTest(t) grpcServer2, addr2 := NewGrpcServerForTest(t) @@ -156,7 +156,7 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g config1.AddOption("clients", "internalsecret", string(testInternalSecret)) config1.AddOption("geoip", "url", "none") - events1, err := NewAsyncEvents(nats) + events1, err := NewAsyncEvents(nats.ClientURL()) require.NoError(err) t.Cleanup(func() { events1.Close() @@ -179,7 +179,7 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g config2.AddOption("sessions", "blockkey", "09876543210987654321098765432109") config2.AddOption("clients", "internalsecret", string(testInternalSecret)) config2.AddOption("geoip", "url", "none") - events2, err := NewAsyncEvents(nats) + events2, err := NewAsyncEvents(nats.ClientURL()) require.NoError(err) t.Cleanup(func() { events2.Close() diff --git a/hub_test.go b/hub_test.go index 0d4bf48..cf191e4 100644 --- a/hub_test.go +++ b/hub_test.go @@ -47,6 +47,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/gorilla/mux" "github.com/gorilla/websocket" + "github.com/nats-io/nats-server/v2/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -190,10 +191,10 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http server2.Close() }) - nats1 := startLocalNatsServer(t) - var nats2 string + nats1, _ := startLocalNatsServer(t) + var nats2 *server.Server if strings.Contains(t.Name(), "Federation") { - nats2 = startLocalNatsServer(t) + nats2, _ = startLocalNatsServer(t) } else { nats2 = nats1 } @@ -205,7 +206,7 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http addr1, addr2 = addr2, addr1 } - events1, err := NewAsyncEvents(nats1) + events1, err := NewAsyncEvents(nats1.ClientURL()) require.NoError(err) t.Cleanup(func() { events1.Close() @@ -217,7 +218,7 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http require.NoError(err) b1, err := NewBackendServer(config1, h1, "no-version") require.NoError(err) - events2, err := NewAsyncEvents(nats2) + events2, err := NewAsyncEvents(nats2.ClientURL()) require.NoError(err) t.Cleanup(func() { events2.Close() diff --git a/natsclient.go b/natsclient.go index 1893a8d..66e8a84 100644 --- a/natsclient.go +++ b/natsclient.go @@ -67,7 +67,7 @@ type natsClient struct { conn *nats.Conn } -func NewNatsClient(url string) (NatsClient, error) { +func NewNatsClient(url string, options ...nats.Option) (NatsClient, error) { if url == ":loopback:" { log.Printf("WARNING: events url %s is deprecated, please use %s instead", url, NatsLoopbackUrl) url = NatsLoopbackUrl @@ -84,10 +84,13 @@ func NewNatsClient(url string) (NatsClient, error) { client := &natsClient{} - client.conn, err = nats.Connect(url, + options = append([]nats.Option{ nats.ClosedHandler(client.onClosed), nats.DisconnectHandler(client.onDisconnected), - nats.ReconnectHandler(client.onReconnected)) + nats.ReconnectHandler(client.onReconnected), + nats.MaxReconnects(-1), + }, options...) + client.conn, err = nats.Connect(url, options...) ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) defer stop() diff --git a/natsclient_test.go b/natsclient_test.go index 430ef6d..362895b 100644 --- a/natsclient_test.go +++ b/natsclient_test.go @@ -30,29 +30,37 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/nats-io/nats-server/v2/server" natsserver "github.com/nats-io/nats-server/v2/test" ) -func startLocalNatsServer(t *testing.T) string { +func startLocalNatsServer(t *testing.T) (*server.Server, int) { + t.Helper() + return startLocalNatsServerPort(t, server.RANDOM_PORT) +} + +func startLocalNatsServerPort(t *testing.T, port int) (*server.Server, int) { + t.Helper() opts := natsserver.DefaultTestOptions - opts.Port = -1 + opts.Port = port opts.Cluster.Name = "testing" srv := natsserver.RunServer(&opts) t.Cleanup(func() { srv.Shutdown() srv.WaitForShutdown() }) - return srv.ClientURL() + return srv, opts.Port } -func CreateLocalNatsClientForTest(t *testing.T) NatsClient { - url := startLocalNatsServer(t) - result, err := NewNatsClient(url) +func CreateLocalNatsClientForTest(t *testing.T, options ...nats.Option) (*server.Server, int, NatsClient) { + t.Helper() + server, port := startLocalNatsServer(t) + result, err := NewNatsClient(server.ClientURL(), options...) require.NoError(t, err) t.Cleanup(func() { result.Close() }) - return result + return server, port, result } func testNatsClient_Subscribe(t *testing.T, client NatsClient) { @@ -100,7 +108,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { func TestNatsClient_Subscribe(t *testing.T) { CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { - client := CreateLocalNatsClientForTest(t) + _, _, client := CreateLocalNatsClientForTest(t) testNatsClient_Subscribe(t, client) }) @@ -115,7 +123,7 @@ func testNatsClient_PublishAfterClose(t *testing.T, client NatsClient) { func TestNatsClient_PublishAfterClose(t *testing.T) { CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { - client := CreateLocalNatsClientForTest(t) + _, _, client := CreateLocalNatsClientForTest(t) testNatsClient_PublishAfterClose(t, client) }) @@ -132,7 +140,7 @@ func testNatsClient_SubscribeAfterClose(t *testing.T, client NatsClient) { func TestNatsClient_SubscribeAfterClose(t *testing.T) { CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { - client := CreateLocalNatsClientForTest(t) + _, _, client := CreateLocalNatsClientForTest(t) testNatsClient_SubscribeAfterClose(t, client) }) @@ -155,8 +163,38 @@ func testNatsClient_BadSubjects(t *testing.T, client NatsClient) { func TestNatsClient_BadSubjects(t *testing.T) { CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { - client := CreateLocalNatsClientForTest(t) + _, _, client := CreateLocalNatsClientForTest(t) testNatsClient_BadSubjects(t, client) }) } + +func TestNatsClient_MaxReconnects(t *testing.T) { + CatchLogForTest(t) + ensureNoGoroutinesLeak(t, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + reconnectWait := 5 * time.Millisecond + server, port, client := CreateLocalNatsClientForTest(t, + nats.ReconnectWait(reconnectWait), + nats.ReconnectJitter(0, 0), + ) + c, ok := client.(*natsClient) + require.True(ok, "wrong class: %T", client) + require.True(c.conn.IsConnected(), "not connected initially") + assert.Equal(server.ID(), c.conn.ConnectedServerId()) + + server.Shutdown() + server.WaitForShutdown() + + // The NATS client tries to reconnect a maximum of 100 times by default. + time.Sleep(time.Second + (100 * reconnectWait)) + require.False(c.conn.IsConnected(), "should be disconnected after server shutdown") + + server, _ = startLocalNatsServerPort(t, port) + + time.Sleep(time.Second) + require.True(c.conn.IsConnected(), "not connected after restart") + assert.Equal(server.ID(), c.conn.ConnectedServerId()) + }) +}