diff --git a/natsclient_loopback_test.go b/natsclient_loopback_test.go index 3cb35b1..ff8994c 100644 --- a/natsclient_loopback_test.go +++ b/natsclient_loopback_test.go @@ -60,7 +60,7 @@ func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient { } func TestLoopbackNatsClient_Subscribe(t *testing.T) { - ensureNoGoroutinesLeak(t, func() { + ensureNoGoroutinesLeak(t, func(t *testing.T) { client := CreateLoopbackNatsClientForTest(t) testNatsClient_Subscribe(t, client) @@ -68,7 +68,7 @@ func TestLoopbackNatsClient_Subscribe(t *testing.T) { } func TestLoopbackClient_PublishAfterClose(t *testing.T) { - ensureNoGoroutinesLeak(t, func() { + ensureNoGoroutinesLeak(t, func(t *testing.T) { client := CreateLoopbackNatsClientForTest(t) testNatsClient_PublishAfterClose(t, client) @@ -76,7 +76,7 @@ func TestLoopbackClient_PublishAfterClose(t *testing.T) { } func TestLoopbackClient_SubscribeAfterClose(t *testing.T) { - ensureNoGoroutinesLeak(t, func() { + ensureNoGoroutinesLeak(t, func(t *testing.T) { client := CreateLoopbackNatsClientForTest(t) testNatsClient_SubscribeAfterClose(t, client) @@ -84,7 +84,7 @@ func TestLoopbackClient_SubscribeAfterClose(t *testing.T) { } func TestLoopbackClient_BadSubjects(t *testing.T) { - ensureNoGoroutinesLeak(t, func() { + ensureNoGoroutinesLeak(t, func(t *testing.T) { client := CreateLoopbackNatsClientForTest(t) testNatsClient_BadSubjects(t, client) diff --git a/natsclient_test.go b/natsclient_test.go index 62defca..b72d291 100644 --- a/natsclient_test.go +++ b/natsclient_test.go @@ -105,7 +105,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { } func TestNatsClient_Subscribe(t *testing.T) { - ensureNoGoroutinesLeak(t, func() { + ensureNoGoroutinesLeak(t, func(t *testing.T) { client := CreateLocalNatsClientForTest(t) testNatsClient_Subscribe(t, client) @@ -121,7 +121,7 @@ func testNatsClient_PublishAfterClose(t *testing.T, client NatsClient) { } func TestNatsClient_PublishAfterClose(t *testing.T) { - ensureNoGoroutinesLeak(t, func() { + ensureNoGoroutinesLeak(t, func(t *testing.T) { client := CreateLocalNatsClientForTest(t) testNatsClient_PublishAfterClose(t, client) @@ -138,7 +138,7 @@ func testNatsClient_SubscribeAfterClose(t *testing.T, client NatsClient) { } func TestNatsClient_SubscribeAfterClose(t *testing.T) { - ensureNoGoroutinesLeak(t, func() { + ensureNoGoroutinesLeak(t, func(t *testing.T) { client := CreateLocalNatsClientForTest(t) testNatsClient_SubscribeAfterClose(t, client) @@ -160,7 +160,7 @@ func testNatsClient_BadSubjects(t *testing.T, client NatsClient) { } func TestNatsClient_BadSubjects(t *testing.T) { - ensureNoGoroutinesLeak(t, func() { + ensureNoGoroutinesLeak(t, func(t *testing.T) { client := CreateLocalNatsClientForTest(t) testNatsClient_BadSubjects(t, client) diff --git a/testutils_test.go b/testutils_test.go index b789a2d..bcc781f 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -23,19 +23,38 @@ package signaling import ( "os" + "os/signal" "runtime/pprof" + "sync" "testing" "time" ) -func ensureNoGoroutinesLeak(t *testing.T, f func()) { +var listenSignalOnce sync.Once + +func ensureNoGoroutinesLeak(t *testing.T, f func(t *testing.T)) { + t.Helper() + + // The signal package will start a goroutine the first time "signal.Notify" + // is called. Do so outside the function under test so the signal goroutine + // will not be shown as "leaking". + listenSignalOnce.Do(func() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, os.Interrupt) + go func() { + for { + <-ch + } + }() + }) + profile := pprof.Lookup("goroutine") // Give time for things to settle before capturing the number of // go routines time.Sleep(500 * time.Millisecond) before := profile.Count() - f() + t.Run("leakcheck", f) var after int // Give time for things to settle before capturing the number of @@ -50,6 +69,6 @@ func ensureNoGoroutinesLeak(t *testing.T, f func()) { if after != before { profile.WriteTo(os.Stderr, 2) // nolint - t.Fatalf("Number of Go routines has changed in %s from %d to %d", t.Name(), before, after) + t.Fatalf("Number of Go routines has changed from %d to %d", before, after) } }