diff --git a/backend_server_test.go b/backend_server_test.go index b5e3b42..b68b923 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -1100,21 +1100,17 @@ func TestBackendServer_ParticipantsUpdateTimeout(t *testing.T) { } } - ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second+100*time.Millisecond) + ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel2() - if msg1_c, _ := client1.RunUntilMessage(ctx2); msg1_c != nil { - if in_call_2, err := checkMessageParticipantsInCall(msg1_c); assert.NoError(err) { - assert.Len(in_call_2.Users, 2) - } + if msg1_c, err := client1.RunUntilMessage(ctx2); !assert.ErrorIs(err, context.DeadlineExceeded) { + assert.Fail("should have timeout out", "received %+v", msg1_c) } - ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second+100*time.Millisecond) + ctx3, cancel3 := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel3() - if msg2_c, _ := client2.RunUntilMessage(ctx3); msg2_c != nil { - if in_call_2, err := checkMessageParticipantsInCall(msg2_c); assert.NoError(err) { - assert.Len(in_call_2.Users, 2) - } + if msg2_c, err := client2.RunUntilMessage(ctx3); !assert.ErrorIs(err, context.DeadlineExceeded) { + assert.Fail("should have timeout out", "received %+v", msg2_c) } } diff --git a/backoff.go b/backoff.go index 5b49521..27af0f9 100644 --- a/backoff.go +++ b/backoff.go @@ -37,6 +37,8 @@ type exponentialBackoff struct { initial time.Duration maxWait time.Duration nextWait time.Duration + + getContextWithTimeout func(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) } func NewExponentialBackoff(initial time.Duration, maxWait time.Duration) (Backoff, error) { @@ -52,6 +54,8 @@ func NewExponentialBackoff(initial time.Duration, maxWait time.Duration) (Backof maxWait: maxWait, nextWait: initial, + + getContextWithTimeout: context.WithTimeout, }, nil } @@ -64,7 +68,7 @@ func (b *exponentialBackoff) NextWait() time.Duration { } func (b *exponentialBackoff) Wait(ctx context.Context) { - waiter, cancel := context.WithTimeout(ctx, b.nextWait) + waiter, cancel := b.getContextWithTimeout(ctx, b.nextWait) defer cancel() b.nextWait = b.nextWait * 2 diff --git a/backoff_test.go b/backoff_test.go index 87de048..7270c90 100644 --- a/backoff_test.go +++ b/backoff_test.go @@ -31,7 +31,6 @@ import ( ) func TestBackoff_Exponential(t *testing.T) { - t.Parallel() assert := assert.New(t) minWait := 100 * time.Millisecond backoff, err := NewExponentialBackoff(minWait, 500*time.Millisecond) @@ -47,14 +46,27 @@ func TestBackoff_Exponential(t *testing.T) { for _, wait := range waitTimes { assert.Equal(wait, backoff.NextWait()) - - a := time.Now() + backoff.(*exponentialBackoff).getContextWithTimeout = func(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + assert.Equal(wait, timeout) + return context.WithTimeout(parent, time.Millisecond) + } backoff.Wait(context.Background()) - b := time.Now() - assert.GreaterOrEqual(b.Sub(a), wait) } backoff.Reset() + backoff.(*exponentialBackoff).getContextWithTimeout = func(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + assert.Equal(minWait, timeout) + return context.WithTimeout(parent, time.Millisecond) + } + backoff.Wait(context.Background()) +} + +func TestBackoff_ExponentialRealSleep(t *testing.T) { + assert := assert.New(t) + minWait := 100 * time.Millisecond + backoff, err := NewExponentialBackoff(minWait, 500*time.Millisecond) + require.NoError(t, err) + a := time.Now() backoff.Wait(context.Background()) b := time.Now() diff --git a/concurrentmap.go b/concurrentmap.go index 1a4da0d..920fac4 100644 --- a/concurrentmap.go +++ b/concurrentmap.go @@ -26,13 +26,13 @@ import ( ) type ConcurrentStringStringMap struct { - sync.Mutex - d map[string]string + mu sync.RWMutex + d map[string]string } func (m *ConcurrentStringStringMap) Set(key, value string) { - m.Lock() - defer m.Unlock() + m.mu.Lock() + defer m.mu.Unlock() if m.d == nil { m.d = make(map[string]string) } @@ -40,26 +40,26 @@ func (m *ConcurrentStringStringMap) Set(key, value string) { } func (m *ConcurrentStringStringMap) Get(key string) (string, bool) { - m.Lock() - defer m.Unlock() + m.mu.RLock() + defer m.mu.RUnlock() s, found := m.d[key] return s, found } func (m *ConcurrentStringStringMap) Del(key string) { - m.Lock() - defer m.Unlock() + m.mu.Lock() + defer m.mu.Unlock() delete(m.d, key) } func (m *ConcurrentStringStringMap) Len() int { - m.Lock() - defer m.Unlock() + m.mu.RLock() + defer m.mu.RUnlock() return len(m.d) } func (m *ConcurrentStringStringMap) Clear() { - m.Lock() - defer m.Unlock() + m.mu.Lock() + defer m.mu.Unlock() m.d = nil } diff --git a/concurrentmap_test.go b/concurrentmap_test.go index 276b038..990a520 100644 --- a/concurrentmap_test.go +++ b/concurrentmap_test.go @@ -86,8 +86,11 @@ func TestConcurrentStringStringMap(t *testing.T) { for y := 0; y < count; y = y + 1 { value := rnd + "-" + strconv.Itoa(y) m.Set(key, value) - if v, found := m.Get(key); !assert.True(found, "Expected entry for key %s", key) || - !assert.Equal(value, v, "Unexpected value for key %s", key) { + if v, found := m.Get(key); !found { + assert.True(found, "Expected entry for key %s", key) + return + } else if v != value { + assert.Equal(value, v, "Unexpected value for key %s", key) return } } diff --git a/hub_test.go b/hub_test.go index b364623..b8cda73 100644 --- a/hub_test.go +++ b/hub_test.go @@ -1302,36 +1302,48 @@ func TestSessionIdsUnordered(t *testing.T) { assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) + var mu sync.Mutex publicSessionIds := make([]string, 0) + var wg sync.WaitGroup for i := 0; i < 20; i++ { - client := NewTestClient(t, server, hub) - defer client.CloseWithBye() + wg.Add(1) + go func() { + defer wg.Done() + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() - require.NoError(client.SendHello(testDefaultUserId)) + require.NoError(client.SendHello(testDefaultUserId)) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() - if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { - assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) - assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.Equal(testDefaultUserId, hello.Hello.UserId, "%+v", hello.Hello) + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello.Hello) - data := hub.decodePublicSessionId(hello.Hello.SessionId) - if !assert.NotNil(data, "Could not decode session id: %s", hello.Hello.SessionId) { - break + data := hub.decodePublicSessionId(hello.Hello.SessionId) + if !assert.NotNil(data, "Could not decode session id: %s", hello.Hello.SessionId) { + return + } + + hub.mu.RLock() + session := hub.sessions[data.Sid] + hub.mu.RUnlock() + if !assert.NotNil(session, "Could not get session for id %+v", data) { + return + } + + mu.Lock() + publicSessionIds = append(publicSessionIds, session.PublicId()) + mu.Unlock() } - - hub.mu.RLock() - session := hub.sessions[data.Sid] - hub.mu.RUnlock() - if !assert.NotNil(session, "Could not get session for id %+v", data) { - break - } - - publicSessionIds = append(publicSessionIds, session.PublicId()) - } + }() } + wg.Wait() + + mu.Lock() + defer mu.Unlock() require.NotEmpty(publicSessionIds, "no session ids decoded") larger := 0 diff --git a/natsclient_test.go b/natsclient_test.go index 362895b..bc92bfb 100644 --- a/natsclient_test.go +++ b/natsclient_test.go @@ -174,7 +174,7 @@ func TestNatsClient_MaxReconnects(t *testing.T) { ensureNoGoroutinesLeak(t, func(t *testing.T) { assert := assert.New(t) require := require.New(t) - reconnectWait := 5 * time.Millisecond + reconnectWait := time.Millisecond server, port, client := CreateLocalNatsClientForTest(t, nats.ReconnectWait(reconnectWait), nats.ReconnectJitter(0, 0), @@ -188,12 +188,18 @@ func TestNatsClient_MaxReconnects(t *testing.T) { server.WaitForShutdown() // The NATS client tries to reconnect a maximum of 100 times by default. - time.Sleep(time.Second + (100 * reconnectWait)) + time.Sleep(100 * reconnectWait) + for i := 0; i < 1000 && c.conn.IsConnected(); i++ { + time.Sleep(time.Millisecond) + } require.False(c.conn.IsConnected(), "should be disconnected after server shutdown") server, _ = startLocalNatsServerPort(t, port) - time.Sleep(time.Second) + // Wait for automatic reconnection + for i := 0; i < 1000 && !c.conn.IsConnected(); i++ { + time.Sleep(time.Millisecond) + } require.True(c.conn.IsConnected(), "not connected after restart") assert.Equal(server.ID(), c.conn.ConnectedServerId()) }) diff --git a/testutils_test.go b/testutils_test.go index f2d507a..8aaa076 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -57,8 +57,15 @@ func ensureNoGoroutinesLeak(t *testing.T, f func(t *testing.T)) { profile := pprof.Lookup("goroutine") // Give time for things to settle before capturing the number of // go routines - time.Sleep(500 * time.Millisecond) - before := profile.Count() + var before int + timeout := time.Now().Add(time.Second) + for time.Now().Before(timeout) { + before = profile.Count() + time.Sleep(10 * time.Millisecond) + if profile.Count() == before { + break + } + } var prev bytes.Buffer dumpGoroutines("Before:", &prev) @@ -67,7 +74,7 @@ func ensureNoGoroutinesLeak(t *testing.T, f func(t *testing.T)) { var after int // Give time for things to settle before capturing the number of // go routines - timeout := time.Now().Add(time.Second) + timeout = time.Now().Add(time.Second) for time.Now().Before(timeout) { after = profile.Count() if after == before {