diff --git a/hub.go b/hub.go index 57a7f9a..7a52774 100644 --- a/hub.go +++ b/hub.go @@ -986,7 +986,7 @@ func (h *Hub) disconnectByRoomSessionId(roomSessionId string) { session.Close() } -func (h *Hub) sendRoom(client *Client, message *ClientMessage, room *Room) bool { +func (h *Hub) sendRoom(session *ClientSession, message *ClientMessage, room *Room) bool { response := &ServerMessage{ Type: "room", } @@ -1003,7 +1003,7 @@ func (h *Hub) sendRoom(client *Client, message *ClientMessage, room *Room) bool Properties: room.properties, } } - return client.SendMessage(response) + return session.SendMessage(response) } func (h *Hub) processRoom(client *Client, message *ClientMessage) { @@ -1017,7 +1017,7 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { // We can handle leaving a room directly. if session.LeaveRoom(true) != nil { // User was in a room before, so need to notify about leaving it. - h.sendRoom(client, message, nil) + h.sendRoom(session, message, nil) } if session.UserId() == "" && session.ClientType() != HelloClientTypeInternal { h.startWaitAnonymousClientRoom(client) @@ -1054,7 +1054,7 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { } request := NewBackendClientRoomRequest(roomId, session.UserId(), sessionId) if err := h.backend.PerformJSONRequest(ctx, session.ParsedBackendUrl(), request, &room); err != nil { - client.SendMessage(message.NewWrappedErrorServerMessage(err)) + session.SendMessage(message.NewWrappedErrorServerMessage(err)) return } @@ -1067,7 +1067,7 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { } } - h.processJoinRoom(client, message, &room) + h.processJoinRoom(session, message, &room) } func (h *Hub) getRoomForBackend(id string, backend *Backend) *Room { @@ -1097,18 +1097,12 @@ func (h *Hub) createRoom(id string, properties *json.RawMessage, backend *Backen return room, nil } -func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *BackendClientResponse) { - session := client.GetSession() - if session == nil { - // Client disconnected while waiting for join room response. - return - } - +func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, room *BackendClientResponse) { if room.Type == "error" { - client.SendMessage(message.NewErrorServerMessage(room.Error)) + session.SendMessage(message.NewErrorServerMessage(room.Error)) return } else if room.Type != "room" { - client.SendMessage(message.NewErrorServerMessage(RoomJoinFailed)) + session.SendMessage(message.NewErrorServerMessage(RoomJoinFailed)) return } @@ -1117,9 +1111,9 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back roomId := room.Room.RoomId internalRoomId := getRoomIdForBackend(roomId, session.Backend()) if err := session.SubscribeRoomNats(h.nats, roomId, message.Room.SessionId); err != nil { - client.SendMessage(message.NewWrappedErrorServerMessage(err)) + session.SendMessage(message.NewWrappedErrorServerMessage(err)) // The client (implicitly) left the room due to an error. - h.sendRoom(client, nil, nil) + h.sendRoom(session, nil, nil) return } @@ -1129,28 +1123,30 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back var err error if r, err = h.createRoom(roomId, room.Room.Properties, session.Backend()); err != nil { h.ru.Unlock() - client.SendMessage(message.NewWrappedErrorServerMessage(err)) + session.SendMessage(message.NewWrappedErrorServerMessage(err)) // The client (implicitly) left the room due to an error. session.UnsubscribeRoomNats() - h.sendRoom(client, nil, nil) + h.sendRoom(session, nil, nil) return } } h.ru.Unlock() h.mu.Lock() - // The client now joined a room, don't expire him if he is anonymous. - delete(h.anonymousClients, client) + if client := session.GetClient(); client != nil { + // The client now joined a room, don't expire him if he is anonymous. + delete(h.anonymousClients, client) + } h.mu.Unlock() session.SetRoom(r) if room.Room.Permissions != nil { session.SetPermissions(*room.Room.Permissions) } - h.sendRoom(client, message, r) - h.notifyUserJoinedRoom(r, client, session, room.Room.Session) + h.sendRoom(session, message, r) + h.notifyUserJoinedRoom(r, session, room.Room.Session) } -func (h *Hub) notifyUserJoinedRoom(room *Room, client *Client, session Session, sessionData *json.RawMessage) { +func (h *Hub) notifyUserJoinedRoom(room *Room, session *ClientSession, sessionData *json.RawMessage) { // Register session with the room if sessions := room.AddSession(session, sessionData); len(sessions) > 0 { events := make([]*EventServerMessageSessionEntry, 0, len(sessions)) @@ -1171,7 +1167,7 @@ func (h *Hub) notifyUserJoinedRoom(room *Room, client *Client, session Session, } // No need to send through NATS, the session is connected locally. - client.SendMessage(msg) + session.SendMessage(msg) } } @@ -1491,14 +1487,14 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { if err := h.backend.PerformJSONRequest(ctx, session.ParsedBackendUrl(), request, &response); err != nil { log.Printf("Could not join virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err) reply := message.NewErrorServerMessage(NewError("add_failed", "Could not join virtual session.")) - client.SendMessage(reply) + session.SendMessage(reply) return } if response.Type == "error" { log.Printf("Could not join virtual session %s at backend %s: %+v", virtualSessionId, session.BackendUrl(), response.Error) reply := message.NewErrorServerMessage(NewError("add_failed", response.Error.Error())) - client.SendMessage(reply) + session.SendMessage(reply) return } } else { @@ -1507,7 +1503,7 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { if err := h.backend.PerformJSONRequest(ctx, session.ParsedBackendUrl(), request, &response); err != nil { log.Printf("Could not add virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err) reply := message.NewErrorServerMessage(NewError("add_failed", "Could not add virtual session.")) - client.SendMessage(reply) + session.SendMessage(reply) return } } @@ -1763,7 +1759,7 @@ func (h *Hub) processRoomDeleted(message *BackendServerRoomRequest) { switch sess := session.(type) { case *ClientSession: if client := sess.GetClient(); client != nil { - h.sendRoom(client, nil, nil) + h.sendRoom(sess, nil, nil) } } } diff --git a/hub_test.go b/hub_test.go index 9b241e7..654ed00 100644 --- a/hub_test.go +++ b/hub_test.go @@ -254,7 +254,10 @@ func processRoomRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re t.Fatalf("Expected an room backend request, got %+v", request) } - if request.Room.RoomId == "test-room-takeover-room-session" { + switch request.Room.RoomId { + case "test-room-slow": + time.Sleep(100 * time.Millisecond) + case "test-room-takeover-room-session": // Additional checks for testcase "TestClientTakeoverRoomSession" if request.Room.Action == "leave" && request.Room.UserId == "test-userid1" { t.Errorf("Should not receive \"leave\" event for first user, received %+v", request.Room) @@ -1754,6 +1757,95 @@ func TestJoinMultiple(t *testing.T) { } } +func TestJoinRoomSwitchClient(t *testing.T) { + hub, _, _, server, shutdown := CreateHubForTest(t) + defer shutdown() + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + if err := client.SendHello(testDefaultUserId); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + hello, err := client.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + + // Join room by id. + roomId := "test-room-slow" + msg := &ClientMessage{ + Id: "ABCD", + Type: "room", + Room: &RoomClientMessage{ + RoomId: roomId, + SessionId: roomId + "-" + hello.Hello.SessionId, + }, + } + if err := client.WriteJSON(msg); err != nil { + t.Fatal(err) + } + // Wait a bit to make sure request is sent before closing client. + time.Sleep(1 * time.Millisecond) + client.Close() + if err := client.WaitForClientRemoved(ctx); err != nil { + t.Fatal(err) + } + + // The client needs some time to reconnect. + time.Sleep(200 * time.Millisecond) + + client2 := NewTestClient(t, server, hub) + defer client2.CloseWithBye() + if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil { + t.Fatal(err) + } + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Error(err) + } else { + if hello2.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) + } + if hello2.Hello.SessionId != hello.Hello.SessionId { + t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) + } + if hello2.Hello.ResumeId != hello.Hello.ResumeId { + t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) + } + } + + room, err := client2.RunUntilMessage(ctx) + if err != nil { + t.Fatal(err) + } + if err := checkUnexpectedClose(err); err != nil { + t.Fatal(err) + } + if err := checkMessageType(room, "room"); err != nil { + t.Fatal(err) + } + if room.Room.RoomId != roomId { + t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + } + + // We will receive a "joined" event. + if err := client2.RunUntilJoined(ctx, hello.Hello); err != nil { + t.Error(err) + } + + // Leave room. + if room, err := client2.JoinRoom(ctx, ""); err != nil { + t.Fatal(err) + } else if room.Room.RoomId != "" { + t.Fatalf("Expected empty room, got %s", room.Room.RoomId) + } +} + func TestGetRealUserIP(t *testing.T) { REMOTE_ATTR := "192.168.1.2"