diff --git a/clientsession.go b/clientsession.go index a10e262..1b67ab1 100644 --- a/clientsession.go +++ b/clientsession.go @@ -86,6 +86,9 @@ type ClientSession struct { hasPendingParticipantsUpdate bool virtualSessions map[*VirtualSession]bool + + seenJoinedLock sync.Mutex + seenJoinedEvents map[string]bool } func NewClientSession(hub *Hub, privateId string, publicId string, data *SessionIdData, backend *Backend, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) { @@ -335,6 +338,10 @@ func (s *ClientSession) SetRoom(room *Room) { } else { atomic.StoreInt64(&s.roomJoinTime, 0) } + + s.seenJoinedLock.Lock() + defer s.seenJoinedLock.Unlock() + s.seenJoinedEvents = nil } func (s *ClientSession) GetRoom() *Room { @@ -1153,6 +1160,29 @@ func filterDisplayNames(events []*EventServerMessageSessionEntry) []*EventServer return result } +func (s *ClientSession) filterDuplicateJoin(entries []*EventServerMessageSessionEntry) []*EventServerMessageSessionEntry { + s.seenJoinedLock.Lock() + defer s.seenJoinedLock.Unlock() + + // Due to the asynchronous events, a session might received a "Joined" event + // for the same (other) session twice, so filter these out on a per-session + // level. + result := make([]*EventServerMessageSessionEntry, 0, len(entries)) + for _, e := range entries { + if s.seenJoinedEvents[e.SessionId] { + log.Printf("Session %s got duplicate joined event for %s, ignoring", s.publicId, e.SessionId) + continue + } + + if s.seenJoinedEvents == nil { + s.seenJoinedEvents = make(map[string]bool) + } + s.seenJoinedEvents[e.SessionId] = true + result = append(result, e) + } + return result +} + func (s *ClientSession) filterMessage(message *ServerMessage) *ServerMessage { switch message.Type { case "event": @@ -1177,18 +1207,47 @@ func (s *ClientSession) filterMessage(message *ServerMessage) *ServerMessage { case "room": switch message.Event.Type { case "join": - if s.HasPermission(PERMISSION_HIDE_DISPLAYNAMES) { + join := s.filterDuplicateJoin(message.Event.Join) + if len(join) == 0 { + return nil + } + copied := false + if len(join) != len(message.Event.Join) { // Create unique copy of message for only this client. + copied = true message = &ServerMessage{ Id: message.Id, Type: message.Type, Event: &EventServerMessage{ Type: message.Event.Type, Target: message.Event.Target, - Join: filterDisplayNames(message.Event.Join), + Join: join, }, } } + + if s.HasPermission(PERMISSION_HIDE_DISPLAYNAMES) { + if copied { + message.Event.Join = filterDisplayNames(message.Event.Join) + } else { + message = &ServerMessage{ + Id: message.Id, + Type: message.Type, + Event: &EventServerMessage{ + Type: message.Event.Type, + Target: message.Event.Target, + Join: filterDisplayNames(message.Event.Join), + }, + } + } + } + case "leave": + s.seenJoinedLock.Lock() + defer s.seenJoinedLock.Unlock() + + for _, e := range message.Event.Leave { + delete(s.seenJoinedEvents, e) + } case "message": if message.Event.Message == nil || message.Event.Message.Data == nil || len(*message.Event.Message.Data) == 0 || !s.HasPermission(PERMISSION_HIDE_DISPLAYNAMES) { return message @@ -1265,7 +1324,7 @@ func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage { // Can happen mostly during tests where an older room async message // could be received by a subscriber that joined after it was sent. if joined := s.getRoomJoinTime(); joined.IsZero() || msg.SendTime.Before(joined) { - log.Printf("Message %+v was sent before room was joined, ignoring", msg.Message) + log.Printf("Message %+v was sent on %s before room was joined on %s, ignoring", msg.Message, msg.SendTime, joined) return nil } } diff --git a/hub_test.go b/hub_test.go index 6d70cc1..12d0ec1 100644 --- a/hub_test.go +++ b/hub_test.go @@ -2401,51 +2401,12 @@ func TestClientMessageToUserIdMultipleSessions(t *testing.T) { func WaitForUsersJoined(ctx context.Context, t *testing.T, client1 *TestClient, hello1 *ServerMessage, client2 *TestClient, hello2 *ServerMessage) { // We will receive "joined" events for all clients. The ordering is not // defined as messages are processed and sent by asynchronous event handlers. - msg1_1, err := client1.RunUntilMessage(ctx) - if err != nil { + if err := client1.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { t.Error(err) } - msg1_2, err := client1.RunUntilMessage(ctx) - if err != nil { + if err := client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello); err != nil { t.Error(err) } - msg2_1, err := client2.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - msg2_2, err := client2.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } - - if err := client1.checkMessageJoined(msg1_1, hello1.Hello); err != nil { - // Ordering is "joined" from client 2, then from client 1 - if err := client1.checkMessageJoined(msg1_1, hello2.Hello); err != nil { - t.Error(err) - } - if err := client1.checkMessageJoined(msg1_2, hello1.Hello); err != nil { - t.Error(err) - } - } else { - // Ordering is "joined" from client 1, then from client 2 - if err := client1.checkMessageJoined(msg1_2, hello2.Hello); err != nil { - t.Error(err) - } - } - if err := client2.checkMessageJoined(msg2_1, hello1.Hello); err != nil { - // Ordering is "joined" from client 2, then from client 1 - if err := client2.checkMessageJoined(msg2_1, hello2.Hello); err != nil { - t.Error(err) - } - if err := client2.checkMessageJoined(msg2_2, hello1.Hello); err != nil { - t.Error(err) - } - } else { - // Ordering is "joined" from client 1, then from client 2 - if err := client2.checkMessageJoined(msg2_2, hello2.Hello); err != nil { - t.Error(err) - } - } } func TestClientMessageToRoom(t *testing.T) { diff --git a/testclient_test.go b/testclient_test.go index 86560c0..2f2fde7 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -770,7 +770,7 @@ func (c *TestClient) RunUntilJoinedAndReturn(ctx context.Context, hello ...*Hell for len(hellos) > 0 { message, err := c.RunUntilMessage(ctx) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("got error while waiting for %+v: %w", hellos, err) } if err := checkMessageType(message, "event"); err != nil {