diff --git a/api_signaling.go b/api_signaling.go index d6729f7..895d65e 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -530,6 +530,7 @@ const ( ServerFeatureSwitchTo = "switchto" ServerFeatureDialout = "dialout" ServerFeatureFederation = "federation" + ServerFeatureRecipientCall = "recipient-call" // Features to send to internal clients only. ServerFeatureInternalVirtualSessions = "virtual-sessions" @@ -549,6 +550,7 @@ var ( ServerFeatureSwitchTo, ServerFeatureDialout, ServerFeatureFederation, + ServerFeatureRecipientCall, } DefaultFeaturesInternal = []string{ ServerFeatureInternalVirtualSessions, @@ -559,6 +561,7 @@ var ( ServerFeatureSwitchTo, ServerFeatureDialout, ServerFeatureFederation, + ServerFeatureRecipientCall, } DefaultWelcomeFeatures = []string{ ServerFeatureAudioVideoPermissions, @@ -570,6 +573,7 @@ var ( ServerFeatureSwitchTo, ServerFeatureDialout, ServerFeatureFederation, + ServerFeatureRecipientCall, } ) @@ -671,6 +675,7 @@ const ( RecipientTypeSession = "session" RecipientTypeUser = "user" RecipientTypeRoom = "room" + RecipientTypeCall = "call" ) type MessageClientMessageRecipient struct { @@ -740,6 +745,8 @@ func (m *MessageClientMessage) CheckValid() error { } switch m.Recipient.Type { case RecipientTypeRoom: + fallthrough + case RecipientTypeCall: // No additional checks required. case RecipientTypeSession: if m.Recipient.SessionId == "" { diff --git a/clientsession.go b/clientsession.go index 0d3f347..c59920a 100644 --- a/clientsession.go +++ b/clientsession.go @@ -1362,18 +1362,34 @@ func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage { switch msg.Message.Type { case "message": - if msg.Message.Message != nil && - msg.Message.Message.Sender != nil && - msg.Message.Message.Sender.SessionId == s.PublicId() { - // Don't send message back to sender (can happen if sent to user or room) - return nil + if msg.Message.Message != nil { + if sender := msg.Message.Message.Sender; sender != nil { + if sender.SessionId == s.PublicId() { + // Don't send message back to sender (can happen if sent to user or room) + return nil + } + if sender.Type == RecipientTypeCall { + if room := s.GetRoom(); room == nil || !room.IsSessionInCall(s) { + // Session is not in call, so discard. + return nil + } + } + } } case "control": - if msg.Message.Control != nil && - msg.Message.Control.Sender != nil && - msg.Message.Control.Sender.SessionId == s.PublicId() { - // Don't send message back to sender (can happen if sent to user or room) - return nil + if msg.Message.Control != nil { + if sender := msg.Message.Control.Sender; sender != nil { + if sender.SessionId == s.PublicId() { + // Don't send message back to sender (can happen if sent to user or room) + return nil + } + if sender.Type == RecipientTypeCall { + if room := s.GetRoom(); room == nil || !room.IsSessionInCall(s) { + // Session is not in call, so discard. + return nil + } + } + } } case "event": if msg.Message.Event.Target == "room" { diff --git a/hub.go b/hub.go index 7c3e78a..747b435 100644 --- a/hub.go +++ b/hub.go @@ -2010,6 +2010,8 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { subject = GetSubjectForUserId(msg.Recipient.UserId, session.Backend()) } case RecipientTypeRoom: + fallthrough + case RecipientTypeCall: if session != nil { if room = session.GetRoom(); room != nil { subject = GetSubjectForRoomId(room.Id(), room.Backend()) @@ -2130,6 +2132,8 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { case RecipientTypeUser: err = h.events.PublishUserMessage(msg.Recipient.UserId, session.Backend(), async) case RecipientTypeRoom: + fallthrough + case RecipientTypeCall: err = h.events.PublishRoomMessage(room.Id(), session.Backend(), async) default: err = fmt.Errorf("unsupported recipient type: %s", msg.Recipient.Type) @@ -2217,6 +2221,8 @@ func (h *Hub) processControlMsg(session Session, message *ClientMessage) { subject = GetSubjectForUserId(msg.Recipient.UserId, session.Backend()) } case RecipientTypeRoom: + fallthrough + case RecipientTypeCall: if session != nil { if room = session.GetRoom(); room != nil { subject = GetSubjectForRoomId(room.Id(), room.Backend()) @@ -2254,6 +2260,8 @@ func (h *Hub) processControlMsg(session Session, message *ClientMessage) { case RecipientTypeUser: err = h.events.PublishUserMessage(msg.Recipient.UserId, session.Backend(), async) case RecipientTypeRoom: + fallthrough + case RecipientTypeCall: err = h.events.PublishRoomMessage(room.Id(), room.Backend(), async) default: err = fmt.Errorf("unsupported recipient type: %s", msg.Recipient.Type) diff --git a/hub_test.go b/hub_test.go index f608149..c9cb30d 100644 --- a/hub_test.go +++ b/hub_test.go @@ -2525,6 +2525,248 @@ func TestClientControlToRoom(t *testing.T) { } } +func TestClientMessageToCall(t *testing.T) { + CatchLogForTest(t) + for _, subtest := range clusteredTests { + t.Run(subtest, func(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + var hub1 *Hub + var hub2 *Hub + var server1 *httptest.Server + var server2 *httptest.Server + + if isLocalTest(t) { + hub1, _, _, server1 = CreateHubForTest(t) + + hub2 = hub1 + server2 = server1 + } else { + hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewTestClient(t, server1, hub1) + defer client1.CloseWithBye() + require.NoError(client1.SendHello(testDefaultUserId + "1")) + hello1, err := client1.RunUntilHello(ctx) + require.NoError(err) + + client2 := NewTestClient(t, server2, hub2) + defer client2.CloseWithBye() + require.NoError(client2.SendHello(testDefaultUserId + "2")) + hello2, err := client2.RunUntilHello(ctx) + require.NoError(err) + + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) + require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) + + // Join room by id. + roomId := "test-room" + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) + + // Give message processing some time. + time.Sleep(10 * time.Millisecond) + + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) + + WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) + + // Simulate request from the backend that somebody joined the call. + users := []map[string]interface{}{ + { + "sessionId": hello1.Hello.SessionId, + "inCall": 1, + }, + } + room1 := hub1.getRoom(roomId) + require.NotNil(room1, "Could not find room %s", roomId) + room1.PublishUsersInCallChanged(users, users) + assert.NoError(checkReceiveClientEvent(ctx, client1, "update", nil)) + assert.NoError(checkReceiveClientEvent(ctx, client2, "update", nil)) + + recipient := MessageClientMessageRecipient{ + Type: "call", + } + + data1 := "from-1-to-2" + client1.SendMessage(recipient, data1) // nolint + data2 := "from-2-to-1" + client2.SendMessage(recipient, data2) // nolint + + var payload string + if err := checkReceiveClientMessage(ctx, client1, "call", hello2.Hello, &payload); assert.NoError(err) { + assert.Equal(data2, payload) + } + + // The second client is not in the call yet, so will not receive the message. + ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel2() + + if message, err := client2.RunUntilMessage(ctx2); err == nil { + assert.Fail("Expected no message", "got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) + } + + // Simulate request from the backend that somebody joined the call. + users = []map[string]interface{}{ + { + "sessionId": hello1.Hello.SessionId, + "inCall": 1, + }, + { + "sessionId": hello2.Hello.SessionId, + "inCall": 1, + }, + } + room2 := hub2.getRoom(roomId) + require.NotNil(room2, "Could not find room %s", roomId) + room2.PublishUsersInCallChanged(users, users) + assert.NoError(checkReceiveClientEvent(ctx, client1, "update", nil)) + assert.NoError(checkReceiveClientEvent(ctx, client2, "update", nil)) + + client1.SendMessage(recipient, data1) // nolint + client2.SendMessage(recipient, data2) // nolint + + if err := checkReceiveClientMessage(ctx, client1, "call", hello2.Hello, &payload); assert.NoError(err) { + assert.Equal(data2, payload) + } + if err := checkReceiveClientMessage(ctx, client2, "call", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) + } + }) + } +} + +func TestClientControlToCall(t *testing.T) { + CatchLogForTest(t) + for _, subtest := range clusteredTests { + t.Run(subtest, func(t *testing.T) { + t.Parallel() + require := require.New(t) + assert := assert.New(t) + var hub1 *Hub + var hub2 *Hub + var server1 *httptest.Server + var server2 *httptest.Server + + if isLocalTest(t) { + hub1, _, _, server1 = CreateHubForTest(t) + + hub2 = hub1 + server2 = server1 + } else { + hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewTestClient(t, server1, hub1) + defer client1.CloseWithBye() + require.NoError(client1.SendHello(testDefaultUserId + "1")) + hello1, err := client1.RunUntilHello(ctx) + require.NoError(err) + + client2 := NewTestClient(t, server2, hub2) + defer client2.CloseWithBye() + require.NoError(client2.SendHello(testDefaultUserId + "2")) + hello2, err := client2.RunUntilHello(ctx) + require.NoError(err) + + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) + require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) + + // Join room by id. + roomId := "test-room" + roomMsg, err := client1.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) + + // Give message processing some time. + time.Sleep(10 * time.Millisecond) + + roomMsg, err = client2.JoinRoom(ctx, roomId) + require.NoError(err) + require.Equal(roomId, roomMsg.Room.RoomId) + + WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) + + // Simulate request from the backend that somebody joined the call. + users := []map[string]interface{}{ + { + "sessionId": hello1.Hello.SessionId, + "inCall": 1, + }, + } + room1 := hub1.getRoom(roomId) + require.NotNil(room1, "Could not find room %s", roomId) + room1.PublishUsersInCallChanged(users, users) + assert.NoError(checkReceiveClientEvent(ctx, client1, "update", nil)) + assert.NoError(checkReceiveClientEvent(ctx, client2, "update", nil)) + + recipient := MessageClientMessageRecipient{ + Type: "call", + } + + data1 := "from-1-to-2" + client1.SendControl(recipient, data1) // nolint + data2 := "from-2-to-1" + client2.SendControl(recipient, data2) // nolint + + var payload string + if err := checkReceiveClientControl(ctx, client1, "call", hello2.Hello, &payload); assert.NoError(err) { + assert.Equal(data2, payload) + } + + // The second client is not in the call yet, so will not receive the message. + ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel2() + + if message, err := client2.RunUntilMessage(ctx2); err == nil { + assert.Fail("Expected no message", "got %+v", message) + } else if err != ErrNoMessageReceived && err != context.DeadlineExceeded { + assert.NoError(err) + } + + // Simulate request from the backend that somebody joined the call. + users = []map[string]interface{}{ + { + "sessionId": hello1.Hello.SessionId, + "inCall": 1, + }, + { + "sessionId": hello2.Hello.SessionId, + "inCall": 1, + }, + } + room2 := hub2.getRoom(roomId) + require.NotNil(room2, "Could not find room %s", roomId) + room2.PublishUsersInCallChanged(users, users) + assert.NoError(checkReceiveClientEvent(ctx, client1, "update", nil)) + assert.NoError(checkReceiveClientEvent(ctx, client2, "update", nil)) + + client1.SendControl(recipient, data1) // nolint + client2.SendControl(recipient, data2) // nolint + + if err := checkReceiveClientControl(ctx, client1, "call", hello2.Hello, &payload); assert.NoError(err) { + assert.Equal(data2, payload) + } + if err := checkReceiveClientControl(ctx, client2, "call", hello1.Hello, &payload); assert.NoError(err) { + assert.Equal(data1, payload) + } + }) + } +} + func TestJoinRoom(t *testing.T) { t.Parallel() CatchLogForTest(t)