diff --git a/hub.go b/hub.go index 427a862..37fd2de 100644 --- a/hub.go +++ b/hub.go @@ -1001,8 +1001,8 @@ func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) { h.processRegister(client, message, backend, auth) } -func (h *Hub) disconnectByRoomSessionId(roomSessionId string, backend *Backend) { - sessionId, err := h.roomSessions.GetSessionId(roomSessionId) +func (h *Hub) disconnectByRoomSessionId(ctx context.Context, roomSessionId string, backend *Backend) { + sessionId, err := h.roomSessions.LookupSessionId(ctx, roomSessionId) if err == ErrNoSuchRoomSession { return } else if err != nil { @@ -1116,7 +1116,10 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { if message.Room.SessionId != "" { // There can only be one connection per Nextcloud Talk session, // disconnect any other connections without sending a "leave" event. - h.disconnectByRoomSessionId(message.Room.SessionId, session.Backend()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + h.disconnectByRoomSessionId(ctx, message.Room.SessionId, session.Backend()) } } diff --git a/hub_test.go b/hub_test.go index 8e9db47..b8502f1 100644 --- a/hub_test.go +++ b/hub_test.go @@ -2848,17 +2848,28 @@ func TestRoomParticipantsListUpdateWhileDisconnected(t *testing.T) { } func TestClientTakeoverRoomSession(t *testing.T) { - for _, backend := range eventBackendsForTest { - t.Run(backend, func(t *testing.T) { + for _, subtest := range clusteredTests { + t.Run(subtest, func(t *testing.T) { RunTestClientTakeoverRoomSession(t) }) } } func RunTestClientTakeoverRoomSession(t *testing.T) { - hub, _, _, server := CreateHubForTest(t) + var hub1 *Hub + var hub2 *Hub + var server1 *httptest.Server + var server2 *httptest.Server + if isLocalTest(t) { + hub1, _, _, server1 = CreateHubForTest(t) - client1 := NewTestClient(t, server, hub) + hub2 = hub1 + server2 = server1 + } else { + hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t) + } + + client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() if err := client1.SendHello(testDefaultUserId + "1"); err != nil { @@ -2882,15 +2893,15 @@ func RunTestClientTakeoverRoomSession(t *testing.T) { t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) } - if hubRoom := hub.getRoom(roomId); hubRoom == nil { + if hubRoom := hub1.getRoom(roomId); hubRoom == nil { t.Fatalf("Room %s does not exist", roomId) } - if session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId); session1 == nil { + if session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId); session1 == nil { t.Fatalf("There should be a session %s", hello1.Hello.SessionId) } - client3 := NewTestClient(t, server, hub) + client3 := NewTestClient(t, server2, hub2) defer client3.CloseWithBye() if err := client3.SendHello(testDefaultUserId + "3"); err != nil { @@ -2911,7 +2922,7 @@ func RunTestClientTakeoverRoomSession(t *testing.T) { // Wait until both users have joined. WaitForUsersJoined(ctx, t, client1, hello1, client3, hello3) - client2 := NewTestClient(t, server, hub) + client2 := NewTestClient(t, server2, hub2) defer client2.CloseWithBye() if err := client2.SendHello(testDefaultUserId + "2"); err != nil { @@ -2948,7 +2959,7 @@ func RunTestClientTakeoverRoomSession(t *testing.T) { } // The first session has been closed - if session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId); session1 != nil { + if session1 := hub1.GetSessionByPublicId(hello1.Hello.SessionId); session1 != nil { t.Errorf("The session %s should have been removed", hello1.Hello.SessionId) } @@ -2958,26 +2969,43 @@ func RunTestClientTakeoverRoomSession(t *testing.T) { t.Error(err) } - // No message about the closing is sent to the new connection. - ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel2() + if isLocalTest(t) { + // No message about the closing is sent to the new connection. + ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel2() - if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { - t.Error(err) - } else if message != nil { - t.Errorf("Expected no message, got %+v", message) - } + if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { + t.Error(err) + } else if message != nil { + t.Errorf("Expected no message, got %+v", message) + } - // The permanently connected client will receive a "left" event from the - // overridden session and a "joined" for the new session. - if err := client3.RunUntilLeft(ctx, hello1.Hello); err != nil { - t.Error(err) - } - if err := client3.RunUntilJoined(ctx, hello2.Hello); err != nil { - t.Error(err) - } + // The permanently connected client will receive a "left" event from the + // overridden session and a "joined" for the new session. In that order as + // both were on the same server. + if err := client3.RunUntilLeft(ctx, hello1.Hello); err != nil { + t.Error(err) + } + if err := client3.RunUntilJoined(ctx, hello2.Hello); err != nil { + t.Error(err) + } + } else { + // In the clustered case, the new connection will receive a "leave" event + // due to the asynchronous events. + if err := client2.RunUntilLeft(ctx, hello1.Hello); err != nil { + t.Error(err) + } - time.Sleep(time.Second) + // The permanently connected client will first a "joined" event from the new + // session (on the same server) and a "left" from the session on the remote + // server (asynchronously). + if err := client3.RunUntilJoined(ctx, hello2.Hello); err != nil { + t.Error(err) + } + if err := client3.RunUntilLeft(ctx, hello1.Hello); err != nil { + t.Error(err) + } + } } func TestClientSendOfferPermissions(t *testing.T) { diff --git a/room.go b/room.go index 16f7dcb..fbcaf7f 100644 --- a/room.go +++ b/room.go @@ -452,6 +452,8 @@ func (r *Room) RemoveSession(session Session) bool { return true } + // Still need to publish an event so sessions on other servers get notified. + r.PublishSessionLeft(session) r.hub.removeRoom(r) r.statsRoomSessionsCurrent.Delete(prometheus.Labels{"clienttype": HelloClientTypeClient}) r.statsRoomSessionsCurrent.Delete(prometheus.Labels{"clienttype": HelloClientTypeInternal})