diff --git a/api_async.go b/api_async.go index bc49dc7..a78adff 100644 --- a/api_async.go +++ b/api_async.go @@ -44,7 +44,8 @@ type AsyncMessage struct { type AsyncRoomMessage struct { Type string `json:"type"` - SessionId string `json:"sessionid,omitempty"` + SessionId string `json:"sessionid,omitempty"` + ClientType string `json:"clienttype,omitempty"` } type SendOfferMessage struct { diff --git a/api_signaling.go b/api_signaling.go index 930663c..0373bb5 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -661,6 +661,14 @@ type RoomEventServerMessage struct { All bool `json:"all,omitempty"` } +func (m *RoomEventServerMessage) String() string { + data, err := json.Marshal(m) + if err != nil { + return fmt.Sprintf("Could not serialize %#v: %s", m, err) + } + return string(data) +} + const ( DisinviteReasonDisinvited = "disinvited" DisinviteReasonDeleted = "deleted" @@ -714,6 +722,14 @@ type EventServerMessage struct { Message *RoomEventMessage `json:"message,omitempty"` } +func (m *EventServerMessage) String() string { + data, err := json.Marshal(m) + if err != nil { + return fmt.Sprintf("Could not serialize %#v: %s", m, err) + } + return string(data) +} + type EventServerMessageSessionEntry struct { SessionId string `json:"sessionid"` UserId string `json:"userid"` diff --git a/clientsession.go b/clientsession.go index ee302a1..5fa381d 100644 --- a/clientsession.go +++ b/clientsession.go @@ -1316,3 +1316,14 @@ func (s *ClientSession) RemoveVirtualSession(session *VirtualSession) { delete(s.virtualSessions, session) s.mu.Unlock() } + +func (s *ClientSession) GetVirtualSessions() []*VirtualSession { + s.mu.Lock() + defer s.mu.Unlock() + + result := make([]*VirtualSession, 0, len(s.virtualSessions)) + for session := range s.virtualSessions { + result = append(result, session) + } + return result +} diff --git a/hub.go b/hub.go index d672886..2f34d82 100644 --- a/hub.go +++ b/hub.go @@ -1306,6 +1306,7 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { } else { subject = "session." + msg.Recipient.SessionId recipientSessionId = msg.Recipient.SessionId + serverRecipient = &msg.Recipient } case RecipientTypeUser: if msg.Recipient.UserId != "" { @@ -1501,8 +1502,12 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) { SessionId: virtualSession.SessionId(), } } + } else { + serverRecipient = &msg.Recipient } h.mu.RUnlock() + } else { + serverRecipient = &msg.Recipient } case RecipientTypeUser: if msg.Recipient.UserId != "" { @@ -1600,6 +1605,14 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { virtualSessionId := GetVirtualSessionId(session, msg.SessionId) + sess, err := NewVirtualSession(session, privateSessionId, publicSessionId, sessionIdData, msg) + if err != nil { + log.Printf("Could not create virtual session %s: %s", virtualSessionId, err) + reply := message.NewErrorServerMessage(NewError("add_failed", "Could not create virtual session.")) + session.SendMessage(reply) + return + } + if msg.Options != nil { request := NewBackendClientRoomRequest(room.Id(), msg.UserId, publicSessionId) request.Room.ActorId = msg.Options.ActorId @@ -1608,6 +1621,7 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { var response BackendClientResponse if err := h.backend.PerformJSONRequest(ctx, session.ParsedBackendUrl(), request, &response); err != nil { + sess.Close() log.Printf("Could not join virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err) reply := message.NewErrorServerMessage(NewError("add_failed", "Could not join virtual session.")) session.SendMessage(reply) @@ -1615,6 +1629,7 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { } if response.Type == "error" { + sess.Close() log.Printf("Could not join virtual session %s at backend %s: %+v", virtualSessionId, session.BackendUrl(), response.Error) reply := message.NewErrorServerMessage(NewError("add_failed", response.Error.Error())) session.SendMessage(reply) @@ -1624,6 +1639,7 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { request := NewBackendClientSessionRequest(room.Id(), "add", publicSessionId, msg) var response BackendClientSessionResponse if err := h.backend.PerformJSONRequest(ctx, session.ParsedBackendUrl(), request, &response); err != nil { + sess.Close() log.Printf("Could not add virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err) reply := message.NewErrorServerMessage(NewError("add_failed", "Could not add virtual session.")) session.SendMessage(reply) @@ -1631,7 +1647,6 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { } } - sess := NewVirtualSession(session, privateSessionId, publicSessionId, sessionIdData, msg) h.mu.Lock() h.sessions[sessionIdData.Sid] = sess h.virtualSessions[virtualSessionId] = sessionIdData.Sid diff --git a/hub_test.go b/hub_test.go index 4d49759..bb0fbe8 100644 --- a/hub_test.go +++ b/hub_test.go @@ -24,6 +24,7 @@ package signaling import ( "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" @@ -404,12 +405,47 @@ func processRoomRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re return response } +var ( + sessionRequestHander struct { + sync.Mutex + handlers map[*testing.T]func(*BackendClientSessionRequest) + } +) + +func setSessionRequestHandler(t *testing.T, f func(*BackendClientSessionRequest)) { + sessionRequestHander.Lock() + defer sessionRequestHander.Unlock() + if sessionRequestHander.handlers == nil { + sessionRequestHander.handlers = make(map[*testing.T]func(*BackendClientSessionRequest)) + } + if _, found := sessionRequestHander.handlers[t]; !found { + t.Cleanup(func() { + sessionRequestHander.Lock() + defer sessionRequestHander.Unlock() + + delete(sessionRequestHander.handlers, t) + }) + } + sessionRequestHander.handlers[t] = f +} + +func clearSessionRequestHandler(t *testing.T) { // nolint + sessionRequestHander.Lock() + defer sessionRequestHander.Unlock() + + delete(sessionRequestHander.handlers, t) +} + func processSessionRequest(t *testing.T, w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse { if request.Type != "session" || request.Session == nil { t.Fatalf("Expected an session backend request, got %+v", request) } - // TODO(jojo): Evaluate request. + sessionRequestHander.Lock() + defer sessionRequestHander.Unlock() + if f, found := sessionRequestHander.handlers[t]; found { + f(request.Session) + } response := &BackendClientResponse{ Type: "session", @@ -4134,3 +4170,340 @@ func TestClientUnshareScreen(t *testing.T) { t.Fatalf("Publisher %s should be closed", hello1.Hello.SessionId) } } + +func TestVirtualClientSessions(t *testing.T) { + for _, subtest := range clusteredTests { + t.Run(subtest, func(t *testing.T) { + var hub1 *Hub + var hub2 *Hub + var server1 *httptest.Server + var server2 *httptest.Server + if isLocalTest(t) { + hub1, _, _, server1 = CreateHubForTest(t) + + hub2 = hub1 + server2 = server1 + } else { + hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewTestClient(t, server1, hub1) + defer client1.CloseWithBye() + + if err := client1.SendHello(testDefaultUserId); err != nil { + t.Fatal(err) + } + + hello1, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + + roomId := "test-room" + if _, err := client1.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } + + if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { + t.Error(err) + } + + client2 := NewTestClient(t, server2, hub2) + defer client2.CloseWithBye() + + if err := client2.SendHelloInternal(); err != nil { + t.Fatal(err) + } + + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + session2 := hub2.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession) + if session2 == nil { + t.Fatalf("Session %s does not exist", hello2.Hello.SessionId) + } + + if _, err := client2.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } + + if err := client1.RunUntilJoined(ctx, hello2.Hello); err != nil { + t.Error(err) + } + + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if msg, err := checkMessageParticipantsInCall(msg); err != nil { + t.Error(err) + } else if len(msg.Users) != 1 { + t.Errorf("Expected one user, got %+v", msg) + } else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v { + t.Errorf("Expected internal flag, got %+v", msg) + } else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId { + t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg) + } else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != 3 { + t.Errorf("Expected inCall flag 3, got %+v", msg) + } + + _, unexpected, err := client2.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello) + if err != nil { + t.Error(err) + } + + if len(unexpected) == 0 { + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else { + unexpected = append(unexpected, msg) + } + } + + if len(unexpected) != 1 { + t.Fatalf("expected one message, got %+v", unexpected) + } + + if msg, err := checkMessageParticipantsInCall(unexpected[0]); err != nil { + t.Error(err) + } else if len(msg.Users) != 1 { + t.Errorf("Expected one user, got %+v", msg) + } else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v { + t.Errorf("Expected internal flag, got %+v", msg) + } else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId { + t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg) + } else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != FlagInCall|FlagWithAudio { + t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithAudio, msg) + } + + calledCtx, calledCancel := context.WithTimeout(ctx, time.Second) + + virtualSessionId := "virtual-session-id" + virtualUserId := "virtual-user-id" + generatedSessionId := GetVirtualSessionId(session2, virtualSessionId) + + setSessionRequestHandler(t, func(request *BackendClientSessionRequest) { + defer calledCancel() + if request.Action != "add" { + t.Errorf("Expected action add, got %+v", request) + } else if request.RoomId != roomId { + t.Errorf("Expected room id %s, got %+v", roomId, request) + } else if request.SessionId == generatedSessionId { + t.Errorf("Expected generated session id %s, got %+v", generatedSessionId, request) + } else if request.UserId != virtualUserId { + t.Errorf("Expected session id %s, got %+v", virtualUserId, request) + } + }) + + if err := client2.SendInternalAddSession(&AddSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: virtualSessionId, + RoomId: roomId, + }, + UserId: virtualUserId, + Flags: FLAG_MUTED_SPEAKING, + }); err != nil { + t.Fatal(err) + } + <-calledCtx.Done() + if err := calledCtx.Err(); err != nil && !errors.Is(err, context.Canceled) { + t.Fatal(err) + } + + virtualSessions := session2.GetVirtualSessions() + for len(virtualSessions) == 0 { + time.Sleep(time.Millisecond) + virtualSessions = session2.GetVirtualSessions() + } + + virtualSession := virtualSessions[0] + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if err := client1.checkMessageJoinedSession(msg, virtualSession.PublicId(), virtualUserId); err != nil { + t.Error(err) + } + + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if msg, err := checkMessageParticipantsInCall(msg); err != nil { + t.Error(err) + } else if len(msg.Users) != 2 { + t.Errorf("Expected two users, got %+v", msg) + } else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v { + t.Errorf("Expected internal flag, got %+v", msg) + } else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId { + t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg) + } else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != FlagInCall|FlagWithAudio { + t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithAudio, msg) + } else if v, ok := msg.Users[1]["virtual"].(bool); !ok || !v { + t.Errorf("Expected virtual flag, got %+v", msg) + } else if v, ok := msg.Users[1]["sessionId"].(string); !ok || v != virtualSession.PublicId() { + t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) + } else if v, ok := msg.Users[1]["inCall"].(float64); !ok || v != FlagInCall|FlagWithPhone { + t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithPhone, msg) + } + + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if flags, err := checkMessageParticipantFlags(msg); err != nil { + t.Error(err) + } else if flags.RoomId != roomId { + t.Errorf("Expected room id %s, got %+v", roomId, msg) + } else if flags.SessionId != virtualSession.PublicId() { + t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) + } else if flags.Flags != FLAG_MUTED_SPEAKING { + t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, msg) + } + + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if err := client2.checkMessageJoinedSession(msg, virtualSession.PublicId(), virtualUserId); err != nil { + t.Error(err) + } + + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if msg, err := checkMessageParticipantsInCall(msg); err != nil { + t.Error(err) + } else if len(msg.Users) != 2 { + t.Errorf("Expected two users, got %+v", msg) + } else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v { + t.Errorf("Expected internal flag, got %+v", msg) + } else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId { + t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg) + } else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != FlagInCall|FlagWithAudio { + t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithAudio, msg) + } else if v, ok := msg.Users[1]["virtual"].(bool); !ok || !v { + t.Errorf("Expected virtual flag, got %+v", msg) + } else if v, ok := msg.Users[1]["sessionId"].(string); !ok || v != virtualSession.PublicId() { + t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) + } else if v, ok := msg.Users[1]["inCall"].(float64); !ok || v != FlagInCall|FlagWithPhone { + t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithPhone, msg) + } + + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if flags, err := checkMessageParticipantFlags(msg); err != nil { + t.Error(err) + } else if flags.RoomId != roomId { + t.Errorf("Expected room id %s, got %+v", roomId, msg) + } else if flags.SessionId != virtualSession.PublicId() { + t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) + } else if flags.Flags != FLAG_MUTED_SPEAKING { + t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, msg) + } + + updatedFlags := uint32(0) + if err := client2.SendInternalUpdateSession(&UpdateSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: virtualSessionId, + RoomId: roomId, + }, + + Flags: &updatedFlags, + }); err != nil { + t.Fatal(err) + } + + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if flags, err := checkMessageParticipantFlags(msg); err != nil { + t.Error(err) + } else if flags.RoomId != roomId { + t.Errorf("Expected room id %s, got %+v", roomId, msg) + } else if flags.SessionId != virtualSession.PublicId() { + t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) + } else if flags.Flags != 0 { + t.Errorf("Expected flags %d, got %+v", 0, msg) + } + + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if flags, err := checkMessageParticipantFlags(msg); err != nil { + t.Error(err) + } else if flags.RoomId != roomId { + t.Errorf("Expected room id %s, got %+v", roomId, msg) + } else if flags.SessionId != virtualSession.PublicId() { + t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg) + } else if flags.Flags != 0 { + t.Errorf("Expected flags %d, got %+v", 0, msg) + } + + calledCtx, calledCancel = context.WithTimeout(ctx, time.Second) + + setSessionRequestHandler(t, func(request *BackendClientSessionRequest) { + defer calledCancel() + if request.Action != "remove" { + t.Errorf("Expected action remove, got %+v", request) + } else if request.RoomId != roomId { + t.Errorf("Expected room id %s, got %+v", roomId, request) + } else if request.SessionId == generatedSessionId { + t.Errorf("Expected generated session id %s, got %+v", generatedSessionId, request) + } else if request.UserId != virtualUserId { + t.Errorf("Expected user id %s, got %+v", virtualUserId, request) + } + }) + + // Messages to virtual sessions are sent to the associated client session. + virtualRecipient := MessageClientMessageRecipient{ + Type: "session", + SessionId: virtualSession.PublicId(), + } + + data := "message-to-virtual" + client1.SendMessage(virtualRecipient, data) // nolint + + var payload string + var sender *MessageServerMessageSender + var recipient *MessageClientMessageRecipient + if err := checkReceiveClientMessageWithSenderAndRecipient(ctx, client2, "session", hello1.Hello, &payload, &sender, &recipient); err != nil { + t.Error(err) + } else if recipient.SessionId != virtualSessionId { + t.Errorf("Expected session id %s, got %+v", virtualSessionId, recipient) + } else if payload != data { + t.Errorf("Expected payload %s, got %s", data, payload) + } + + data = "control-to-virtual" + client1.SendControl(virtualRecipient, data) // nolint + + if err := checkReceiveClientControlWithSenderAndRecipient(ctx, client2, "session", hello1.Hello, &payload, &sender, &recipient); err != nil { + t.Error(err) + } else if recipient.SessionId != virtualSessionId { + t.Errorf("Expected session id %s, got %+v", virtualSessionId, recipient) + } else if payload != data { + t.Errorf("Expected payload %s, got %s", data, payload) + } + + if err := client2.SendInternalRemoveSession(&RemoveSessionInternalClientMessage{ + CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{ + SessionId: virtualSessionId, + RoomId: roomId, + }, + + UserId: virtualUserId, + }); err != nil { + t.Fatal(err) + } + <-calledCtx.Done() + if err := calledCtx.Err(); err != nil && !errors.Is(err, context.Canceled) { + t.Fatal(err) + } + + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if err := client1.checkMessageRoomLeaveSession(msg, virtualSession.PublicId()); err != nil { + t.Error(err) + } + + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else if err := client2.checkMessageRoomLeaveSession(msg, virtualSession.PublicId()); err != nil { + t.Error(err) + } + + }) + } +} diff --git a/room.go b/room.go index fbcaf7f..5c1bb4a 100644 --- a/room.go +++ b/room.go @@ -247,6 +247,9 @@ func (r *Room) processBackendRoomRequestAsyncRoom(message *AsyncRoomMessage) { switch message.Type { case "sessionjoined": r.notifySessionJoined(message.SessionId) + if message.ClientType == HelloClientTypeInternal { + r.publishUsersChangedWithInternal() + } default: log.Printf("Unsupported async room request with type %s in %s: %+v", message.Type, r.Id(), message) } @@ -305,8 +308,9 @@ func (r *Room) AddSession(session Session, sessionData *json.RawMessage) { if err := r.events.PublishBackendRoomMessage(r.id, r.backend, &AsyncMessage{ Type: "asyncroom", AsyncRoom: &AsyncRoomMessage{ - Type: "sessionjoined", - SessionId: sid, + Type: "sessionjoined", + SessionId: sid, + ClientType: session.ClientType(), }, }); err != nil { log.Printf("Error publishing joined event for session %s: %s", sid, err) @@ -452,8 +456,6 @@ func (r *Room) RemoveSession(session Session) bool { return true } - // Still need to publish an event so sessions on other servers get notified. - r.PublishSessionLeft(session) r.hub.removeRoom(r) r.statsRoomSessionsCurrent.Delete(prometheus.Labels{"clienttype": HelloClientTypeClient}) r.statsRoomSessionsCurrent.Delete(prometheus.Labels{"clienttype": HelloClientTypeInternal}) @@ -461,6 +463,8 @@ func (r *Room) RemoveSession(session Session) bool { r.unsubscribeBackend() r.doClose() r.mu.Unlock() + // Still need to publish an event so sessions on other servers get notified. + r.PublishSessionLeft(session) return false } @@ -530,10 +534,6 @@ func (r *Room) PublishSessionJoined(session Session, sessionData *RoomSessionDat if err := r.publish(message); err != nil { log.Printf("Could not publish session joined message in room %s: %s", r.Id(), err) } - - if session.ClientType() == HelloClientTypeInternal { - r.publishUsersChangedWithInternal() - } } func (r *Room) PublishSessionLeft(session Session) { @@ -564,6 +564,7 @@ func (r *Room) PublishSessionLeft(session Session) { func (r *Room) addInternalSessions(users []map[string]interface{}) []map[string]interface{} { now := time.Now().Unix() r.mu.Lock() + defer r.mu.Unlock() for _, user := range users { sessionid, found := user["sessionId"] if !found || sessionid == "" { @@ -592,7 +593,6 @@ func (r *Room) addInternalSessions(users []map[string]interface{}) []map[string] "virtual": true, }) } - r.mu.Unlock() return users } @@ -840,6 +840,10 @@ func (r *Room) NotifySessionChanged(session Session) { func (r *Room) publishUsersChangedWithInternal() { message := r.getParticipantsUpdateMessage(r.users) + if len(message.Event.Update.Users) == 0 { + return + } + if err := r.publish(message); err != nil { log.Printf("Could not publish users changed message in room %s: %s", r.Id(), err) } diff --git a/testclient_test.go b/testclient_test.go index dbe3101..f3d33ff 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -80,44 +80,36 @@ func checkUnexpectedClose(err error) error { return nil } -func toJsonString(o interface{}) string { - if s, err := json.Marshal(o); err != nil { - panic(err) - } else { - return string(s) - } -} - func checkMessageType(message *ServerMessage, expectedType string) error { if message == nil { return ErrNoMessageReceived } if message.Type != expectedType { - return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message)) + return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } switch message.Type { case "hello": if message.Hello == nil { - return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message)) + return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } case "message": if message.Message == nil { - return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message)) + return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } else if message.Message.Data == nil || len(*message.Message.Data) == 0 { return fmt.Errorf("Received message without data") } case "room": if message.Room == nil { - return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message)) + return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } case "event": if message.Event == nil { - return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message)) + return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } case "transient": if message.TransientData == nil { - return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message)) + return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message) } } @@ -137,7 +129,7 @@ func checkMessageSender(hub *Hub, sender *MessageServerMessageSender, senderType return nil } -func checkReceiveClientMessageWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error { +func checkReceiveClientMessageWithSenderAndRecipient(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender, recipient **MessageClientMessageRecipient) error { message, err := client.RunUntilMessage(ctx) if err := checkUnexpectedClose(err); err != nil { return err @@ -153,14 +145,21 @@ func checkReceiveClientMessageWithSender(ctx context.Context, client *TestClient if sender != nil { *sender = message.Message.Sender } + if recipient != nil { + *recipient = message.Message.Recipient + } return nil } -func checkReceiveClientMessage(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}) error { - return checkReceiveClientMessageWithSender(ctx, client, senderType, hello, payload, nil) +func checkReceiveClientMessageWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error { + return checkReceiveClientMessageWithSenderAndRecipient(ctx, client, senderType, hello, payload, sender, nil) } -func checkReceiveClientControlWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error { +func checkReceiveClientMessage(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}) error { + return checkReceiveClientMessageWithSenderAndRecipient(ctx, client, senderType, hello, payload, nil, nil) +} + +func checkReceiveClientControlWithSenderAndRecipient(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender, recipient **MessageClientMessageRecipient) error { message, err := client.RunUntilMessage(ctx) if err := checkUnexpectedClose(err); err != nil { return err @@ -174,13 +173,20 @@ func checkReceiveClientControlWithSender(ctx context.Context, client *TestClient } } if sender != nil { - *sender = message.Message.Sender + *sender = message.Control.Sender + } + if recipient != nil { + *recipient = message.Control.Recipient } return nil } +func checkReceiveClientControlWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error { // nolint + return checkReceiveClientControlWithSenderAndRecipient(ctx, client, senderType, hello, payload, sender, nil) +} + func checkReceiveClientControl(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}) error { - return checkReceiveClientControlWithSender(ctx, client, senderType, hello, payload, nil) + return checkReceiveClientControlWithSenderAndRecipient(ctx, client, senderType, hello, payload, nil, nil) } func checkReceiveClientEvent(ctx context.Context, client *TestClient, eventType string, msg **EventServerMessage) error { @@ -474,6 +480,42 @@ func (c *TestClient) SendControl(recipient MessageClientMessageRecipient, data i return c.WriteJSON(message) } +func (c *TestClient) SendInternalAddSession(msg *AddSessionInternalClientMessage) error { + message := &ClientMessage{ + Id: "abcd", + Type: "internal", + Internal: &InternalClientMessage{ + Type: "addsession", + AddSession: msg, + }, + } + return c.WriteJSON(message) +} + +func (c *TestClient) SendInternalUpdateSession(msg *UpdateSessionInternalClientMessage) error { + message := &ClientMessage{ + Id: "abcd", + Type: "internal", + Internal: &InternalClientMessage{ + Type: "updatesession", + UpdateSession: msg, + }, + } + return c.WriteJSON(message) +} + +func (c *TestClient) SendInternalRemoveSession(msg *RemoveSessionInternalClientMessage) error { + message := &ClientMessage{ + Id: "abcd", + Type: "internal", + Internal: &InternalClientMessage{ + Type: "removesession", + RemoveSession: msg, + }, + } + return c.WriteJSON(message) +} + func (c *TestClient) SetTransientData(key string, value interface{}) error { payload, err := json.Marshal(value) if err != nil { @@ -672,10 +714,9 @@ func (c *TestClient) RunUntilJoinedAndReturn(ctx context.Context, hello ...*Hell if err := checkMessageType(message, "event"); err != nil { ignored = append(ignored, message) continue - } else if message.Event.Target != "room" { - return nil, nil, fmt.Errorf("Expected event target room, got %+v", message.Event) - } else if message.Event.Type != "join" { - return nil, nil, fmt.Errorf("Expected event type join, got %+v", message.Event) + } else if message.Event.Target != "room" || message.Event.Type != "join" { + ignored = append(ignored, message) + continue } for len(message.Event.Join) > 0 { diff --git a/virtualsession.go b/virtualsession.go index 87943da..45ec832 100644 --- a/virtualsession.go +++ b/virtualsession.go @@ -56,8 +56,8 @@ func GetVirtualSessionId(session *ClientSession, sessionId string) string { return session.PublicId() + "|" + sessionId } -func NewVirtualSession(session *ClientSession, privateId string, publicId string, data *SessionIdData, msg *AddSessionInternalClientMessage) *VirtualSession { - return &VirtualSession{ +func NewVirtualSession(session *ClientSession, privateId string, publicId string, data *SessionIdData, msg *AddSessionInternalClientMessage) (*VirtualSession, error) { + result := &VirtualSession{ hub: session.hub, session: session, privateId: privateId, @@ -70,6 +70,12 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string flags: msg.Flags, options: msg.Options, } + + if err := session.events.RegisterSessionListener(publicId, session.Backend(), result); err != nil { + return nil, err + } + + return result, nil } func (s *VirtualSession) PrivateId() string { @@ -142,6 +148,7 @@ func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *Clie if removed && room != nil { go s.notifyBackendRemoved(room, session, message) } + s.session.events.UnregisterSessionListener(s.PublicId(), s.session.Backend(), s) } func (s *VirtualSession) notifyBackendRemoved(room *Room, session *ClientSession, message *ClientMessage) { @@ -177,7 +184,10 @@ func (s *VirtualSession) notifyBackendRemoved(room *Room, session *ClientSession return } } else { - request := NewBackendClientSessionRequest(room.Id(), "remove", s.PublicId(), nil) + request := NewBackendClientSessionRequest(room.Id(), "remove", s.PublicId(), &AddSessionInternalClientMessage{ + UserId: s.userId, + User: s.userData, + }) var response BackendClientSessionResponse err := s.hub.backend.PerformJSONRequest(ctx, s.ParsedBackendUrl(), request, &response) if err != nil { @@ -252,3 +262,36 @@ func (s *VirtualSession) Flags() uint32 { func (s *VirtualSession) Options() *AddSessionOptions { return s.options } + +func (s *VirtualSession) ProcessAsyncSessionMessage(message *AsyncMessage) { + if message.Type == "message" && message.Message != nil { + switch message.Message.Type { + case "message": + if message.Message.Message != nil && + message.Message.Message.Recipient != nil && + message.Message.Message.Recipient.Type == "session" && + message.Message.Message.Recipient.SessionId == s.PublicId() { + // The client should see his session id as recipient. + message.Message.Message.Recipient = &MessageClientMessageRecipient{ + Type: "session", + SessionId: s.SessionId(), + UserId: s.UserId(), + } + s.session.ProcessAsyncSessionMessage(message) + } + case "control": + if message.Message.Control != nil && + message.Message.Control.Recipient != nil && + message.Message.Control.Recipient.Type == "session" && + message.Message.Control.Recipient.SessionId == s.PublicId() { + // The client should see his session id as recipient. + message.Message.Control.Recipient = &MessageClientMessageRecipient{ + Type: "session", + SessionId: s.SessionId(), + UserId: s.UserId(), + } + s.session.ProcessAsyncSessionMessage(message) + } + } + } +}