mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-05-19 14:06:32 +02:00
Merge pull request #295 from strukturag/virtualsessions-tests
Add tests for virtual sessions.
This commit is contained in:
commit
12a8fa98d0
|
@ -44,7 +44,8 @@ type AsyncMessage struct {
|
|||
type AsyncRoomMessage struct {
|
||||
Type string `json:"type"`
|
||||
|
||||
SessionId string `json:"sessionid,omitempty"`
|
||||
SessionId string `json:"sessionid,omitempty"`
|
||||
ClientType string `json:"clienttype,omitempty"`
|
||||
}
|
||||
|
||||
type SendOfferMessage struct {
|
||||
|
|
|
@ -661,6 +661,14 @@ type RoomEventServerMessage struct {
|
|||
All bool `json:"all,omitempty"`
|
||||
}
|
||||
|
||||
func (m *RoomEventServerMessage) String() string {
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not serialize %#v: %s", m, err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
const (
|
||||
DisinviteReasonDisinvited = "disinvited"
|
||||
DisinviteReasonDeleted = "deleted"
|
||||
|
@ -714,6 +722,14 @@ type EventServerMessage struct {
|
|||
Message *RoomEventMessage `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
func (m *EventServerMessage) String() string {
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not serialize %#v: %s", m, err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
type EventServerMessageSessionEntry struct {
|
||||
SessionId string `json:"sessionid"`
|
||||
UserId string `json:"userid"`
|
||||
|
|
|
@ -1316,3 +1316,14 @@ func (s *ClientSession) RemoveVirtualSession(session *VirtualSession) {
|
|||
delete(s.virtualSessions, session)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *ClientSession) GetVirtualSessions() []*VirtualSession {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
result := make([]*VirtualSession, 0, len(s.virtualSessions))
|
||||
for session := range s.virtualSessions {
|
||||
result = append(result, session)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
17
hub.go
17
hub.go
|
@ -1306,6 +1306,7 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
|
|||
} else {
|
||||
subject = "session." + msg.Recipient.SessionId
|
||||
recipientSessionId = msg.Recipient.SessionId
|
||||
serverRecipient = &msg.Recipient
|
||||
}
|
||||
case RecipientTypeUser:
|
||||
if msg.Recipient.UserId != "" {
|
||||
|
@ -1501,8 +1502,12 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
|
|||
SessionId: virtualSession.SessionId(),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
serverRecipient = &msg.Recipient
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
} else {
|
||||
serverRecipient = &msg.Recipient
|
||||
}
|
||||
case RecipientTypeUser:
|
||||
if msg.Recipient.UserId != "" {
|
||||
|
@ -1600,6 +1605,14 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) {
|
|||
|
||||
virtualSessionId := GetVirtualSessionId(session, msg.SessionId)
|
||||
|
||||
sess, err := NewVirtualSession(session, privateSessionId, publicSessionId, sessionIdData, msg)
|
||||
if err != nil {
|
||||
log.Printf("Could not create virtual session %s: %s", virtualSessionId, err)
|
||||
reply := message.NewErrorServerMessage(NewError("add_failed", "Could not create virtual session."))
|
||||
session.SendMessage(reply)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.Options != nil {
|
||||
request := NewBackendClientRoomRequest(room.Id(), msg.UserId, publicSessionId)
|
||||
request.Room.ActorId = msg.Options.ActorId
|
||||
|
@ -1608,6 +1621,7 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) {
|
|||
|
||||
var response BackendClientResponse
|
||||
if err := h.backend.PerformJSONRequest(ctx, session.ParsedBackendUrl(), request, &response); err != nil {
|
||||
sess.Close()
|
||||
log.Printf("Could not join virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err)
|
||||
reply := message.NewErrorServerMessage(NewError("add_failed", "Could not join virtual session."))
|
||||
session.SendMessage(reply)
|
||||
|
@ -1615,6 +1629,7 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) {
|
|||
}
|
||||
|
||||
if response.Type == "error" {
|
||||
sess.Close()
|
||||
log.Printf("Could not join virtual session %s at backend %s: %+v", virtualSessionId, session.BackendUrl(), response.Error)
|
||||
reply := message.NewErrorServerMessage(NewError("add_failed", response.Error.Error()))
|
||||
session.SendMessage(reply)
|
||||
|
@ -1624,6 +1639,7 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) {
|
|||
request := NewBackendClientSessionRequest(room.Id(), "add", publicSessionId, msg)
|
||||
var response BackendClientSessionResponse
|
||||
if err := h.backend.PerformJSONRequest(ctx, session.ParsedBackendUrl(), request, &response); err != nil {
|
||||
sess.Close()
|
||||
log.Printf("Could not add virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err)
|
||||
reply := message.NewErrorServerMessage(NewError("add_failed", "Could not add virtual session."))
|
||||
session.SendMessage(reply)
|
||||
|
@ -1631,7 +1647,6 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) {
|
|||
}
|
||||
}
|
||||
|
||||
sess := NewVirtualSession(session, privateSessionId, publicSessionId, sessionIdData, msg)
|
||||
h.mu.Lock()
|
||||
h.sessions[sessionIdData.Sid] = sess
|
||||
h.virtualSessions[virtualSessionId] = sessionIdData.Sid
|
||||
|
|
375
hub_test.go
375
hub_test.go
|
@ -24,6 +24,7 @@ package signaling
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -404,12 +405,47 @@ func processRoomRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re
|
|||
return response
|
||||
}
|
||||
|
||||
var (
|
||||
sessionRequestHander struct {
|
||||
sync.Mutex
|
||||
handlers map[*testing.T]func(*BackendClientSessionRequest)
|
||||
}
|
||||
)
|
||||
|
||||
func setSessionRequestHandler(t *testing.T, f func(*BackendClientSessionRequest)) {
|
||||
sessionRequestHander.Lock()
|
||||
defer sessionRequestHander.Unlock()
|
||||
if sessionRequestHander.handlers == nil {
|
||||
sessionRequestHander.handlers = make(map[*testing.T]func(*BackendClientSessionRequest))
|
||||
}
|
||||
if _, found := sessionRequestHander.handlers[t]; !found {
|
||||
t.Cleanup(func() {
|
||||
sessionRequestHander.Lock()
|
||||
defer sessionRequestHander.Unlock()
|
||||
|
||||
delete(sessionRequestHander.handlers, t)
|
||||
})
|
||||
}
|
||||
sessionRequestHander.handlers[t] = f
|
||||
}
|
||||
|
||||
func clearSessionRequestHandler(t *testing.T) { // nolint
|
||||
sessionRequestHander.Lock()
|
||||
defer sessionRequestHander.Unlock()
|
||||
|
||||
delete(sessionRequestHander.handlers, t)
|
||||
}
|
||||
|
||||
func processSessionRequest(t *testing.T, w http.ResponseWriter, r *http.Request, request *BackendClientRequest) *BackendClientResponse {
|
||||
if request.Type != "session" || request.Session == nil {
|
||||
t.Fatalf("Expected an session backend request, got %+v", request)
|
||||
}
|
||||
|
||||
// TODO(jojo): Evaluate request.
|
||||
sessionRequestHander.Lock()
|
||||
defer sessionRequestHander.Unlock()
|
||||
if f, found := sessionRequestHander.handlers[t]; found {
|
||||
f(request.Session)
|
||||
}
|
||||
|
||||
response := &BackendClientResponse{
|
||||
Type: "session",
|
||||
|
@ -4134,3 +4170,340 @@ func TestClientUnshareScreen(t *testing.T) {
|
|||
t.Fatalf("Publisher %s should be closed", hello1.Hello.SessionId)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVirtualClientSessions(t *testing.T) {
|
||||
for _, subtest := range clusteredTests {
|
||||
t.Run(subtest, func(t *testing.T) {
|
||||
var hub1 *Hub
|
||||
var hub2 *Hub
|
||||
var server1 *httptest.Server
|
||||
var server2 *httptest.Server
|
||||
if isLocalTest(t) {
|
||||
hub1, _, _, server1 = CreateHubForTest(t)
|
||||
|
||||
hub2 = hub1
|
||||
server2 = server1
|
||||
} else {
|
||||
hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
client1 := NewTestClient(t, server1, hub1)
|
||||
defer client1.CloseWithBye()
|
||||
|
||||
if err := client1.SendHello(testDefaultUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hello1, err := client1.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
roomId := "test-room"
|
||||
if _, err := client1.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
client2 := NewTestClient(t, server2, hub2)
|
||||
defer client2.CloseWithBye()
|
||||
|
||||
if err := client2.SendHelloInternal(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hello2, err := client2.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
session2 := hub2.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession)
|
||||
if session2 == nil {
|
||||
t.Fatalf("Session %s does not exist", hello2.Hello.SessionId)
|
||||
}
|
||||
|
||||
if _, err := client2.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := client1.RunUntilJoined(ctx, hello2.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if msg, err := checkMessageParticipantsInCall(msg); err != nil {
|
||||
t.Error(err)
|
||||
} else if len(msg.Users) != 1 {
|
||||
t.Errorf("Expected one user, got %+v", msg)
|
||||
} else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v {
|
||||
t.Errorf("Expected internal flag, got %+v", msg)
|
||||
} else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg)
|
||||
} else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != 3 {
|
||||
t.Errorf("Expected inCall flag 3, got %+v", msg)
|
||||
}
|
||||
|
||||
_, unexpected, err := client2.RunUntilJoinedAndReturn(ctx, hello1.Hello, hello2.Hello)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if len(unexpected) == 0 {
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
unexpected = append(unexpected, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if len(unexpected) != 1 {
|
||||
t.Fatalf("expected one message, got %+v", unexpected)
|
||||
}
|
||||
|
||||
if msg, err := checkMessageParticipantsInCall(unexpected[0]); err != nil {
|
||||
t.Error(err)
|
||||
} else if len(msg.Users) != 1 {
|
||||
t.Errorf("Expected one user, got %+v", msg)
|
||||
} else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v {
|
||||
t.Errorf("Expected internal flag, got %+v", msg)
|
||||
} else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg)
|
||||
} else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != FlagInCall|FlagWithAudio {
|
||||
t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithAudio, msg)
|
||||
}
|
||||
|
||||
calledCtx, calledCancel := context.WithTimeout(ctx, time.Second)
|
||||
|
||||
virtualSessionId := "virtual-session-id"
|
||||
virtualUserId := "virtual-user-id"
|
||||
generatedSessionId := GetVirtualSessionId(session2, virtualSessionId)
|
||||
|
||||
setSessionRequestHandler(t, func(request *BackendClientSessionRequest) {
|
||||
defer calledCancel()
|
||||
if request.Action != "add" {
|
||||
t.Errorf("Expected action add, got %+v", request)
|
||||
} else if request.RoomId != roomId {
|
||||
t.Errorf("Expected room id %s, got %+v", roomId, request)
|
||||
} else if request.SessionId == generatedSessionId {
|
||||
t.Errorf("Expected generated session id %s, got %+v", generatedSessionId, request)
|
||||
} else if request.UserId != virtualUserId {
|
||||
t.Errorf("Expected session id %s, got %+v", virtualUserId, request)
|
||||
}
|
||||
})
|
||||
|
||||
if err := client2.SendInternalAddSession(&AddSessionInternalClientMessage{
|
||||
CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{
|
||||
SessionId: virtualSessionId,
|
||||
RoomId: roomId,
|
||||
},
|
||||
UserId: virtualUserId,
|
||||
Flags: FLAG_MUTED_SPEAKING,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
<-calledCtx.Done()
|
||||
if err := calledCtx.Err(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
virtualSessions := session2.GetVirtualSessions()
|
||||
for len(virtualSessions) == 0 {
|
||||
time.Sleep(time.Millisecond)
|
||||
virtualSessions = session2.GetVirtualSessions()
|
||||
}
|
||||
|
||||
virtualSession := virtualSessions[0]
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if err := client1.checkMessageJoinedSession(msg, virtualSession.PublicId(), virtualUserId); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if msg, err := checkMessageParticipantsInCall(msg); err != nil {
|
||||
t.Error(err)
|
||||
} else if len(msg.Users) != 2 {
|
||||
t.Errorf("Expected two users, got %+v", msg)
|
||||
} else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v {
|
||||
t.Errorf("Expected internal flag, got %+v", msg)
|
||||
} else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg)
|
||||
} else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != FlagInCall|FlagWithAudio {
|
||||
t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithAudio, msg)
|
||||
} else if v, ok := msg.Users[1]["virtual"].(bool); !ok || !v {
|
||||
t.Errorf("Expected virtual flag, got %+v", msg)
|
||||
} else if v, ok := msg.Users[1]["sessionId"].(string); !ok || v != virtualSession.PublicId() {
|
||||
t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg)
|
||||
} else if v, ok := msg.Users[1]["inCall"].(float64); !ok || v != FlagInCall|FlagWithPhone {
|
||||
t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithPhone, msg)
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if flags, err := checkMessageParticipantFlags(msg); err != nil {
|
||||
t.Error(err)
|
||||
} else if flags.RoomId != roomId {
|
||||
t.Errorf("Expected room id %s, got %+v", roomId, msg)
|
||||
} else if flags.SessionId != virtualSession.PublicId() {
|
||||
t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg)
|
||||
} else if flags.Flags != FLAG_MUTED_SPEAKING {
|
||||
t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, msg)
|
||||
}
|
||||
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if err := client2.checkMessageJoinedSession(msg, virtualSession.PublicId(), virtualUserId); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if msg, err := checkMessageParticipantsInCall(msg); err != nil {
|
||||
t.Error(err)
|
||||
} else if len(msg.Users) != 2 {
|
||||
t.Errorf("Expected two users, got %+v", msg)
|
||||
} else if v, ok := msg.Users[0]["internal"].(bool); !ok || !v {
|
||||
t.Errorf("Expected internal flag, got %+v", msg)
|
||||
} else if v, ok := msg.Users[0]["sessionId"].(string); !ok || v != hello2.Hello.SessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", hello2.Hello.SessionId, msg)
|
||||
} else if v, ok := msg.Users[0]["inCall"].(float64); !ok || v != FlagInCall|FlagWithAudio {
|
||||
t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithAudio, msg)
|
||||
} else if v, ok := msg.Users[1]["virtual"].(bool); !ok || !v {
|
||||
t.Errorf("Expected virtual flag, got %+v", msg)
|
||||
} else if v, ok := msg.Users[1]["sessionId"].(string); !ok || v != virtualSession.PublicId() {
|
||||
t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg)
|
||||
} else if v, ok := msg.Users[1]["inCall"].(float64); !ok || v != FlagInCall|FlagWithPhone {
|
||||
t.Errorf("Expected inCall flag %d, got %+v", FlagInCall|FlagWithPhone, msg)
|
||||
}
|
||||
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if flags, err := checkMessageParticipantFlags(msg); err != nil {
|
||||
t.Error(err)
|
||||
} else if flags.RoomId != roomId {
|
||||
t.Errorf("Expected room id %s, got %+v", roomId, msg)
|
||||
} else if flags.SessionId != virtualSession.PublicId() {
|
||||
t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg)
|
||||
} else if flags.Flags != FLAG_MUTED_SPEAKING {
|
||||
t.Errorf("Expected flags %d, got %+v", FLAG_MUTED_SPEAKING, msg)
|
||||
}
|
||||
|
||||
updatedFlags := uint32(0)
|
||||
if err := client2.SendInternalUpdateSession(&UpdateSessionInternalClientMessage{
|
||||
CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{
|
||||
SessionId: virtualSessionId,
|
||||
RoomId: roomId,
|
||||
},
|
||||
|
||||
Flags: &updatedFlags,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if flags, err := checkMessageParticipantFlags(msg); err != nil {
|
||||
t.Error(err)
|
||||
} else if flags.RoomId != roomId {
|
||||
t.Errorf("Expected room id %s, got %+v", roomId, msg)
|
||||
} else if flags.SessionId != virtualSession.PublicId() {
|
||||
t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg)
|
||||
} else if flags.Flags != 0 {
|
||||
t.Errorf("Expected flags %d, got %+v", 0, msg)
|
||||
}
|
||||
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if flags, err := checkMessageParticipantFlags(msg); err != nil {
|
||||
t.Error(err)
|
||||
} else if flags.RoomId != roomId {
|
||||
t.Errorf("Expected room id %s, got %+v", roomId, msg)
|
||||
} else if flags.SessionId != virtualSession.PublicId() {
|
||||
t.Errorf("Expected session id %s, got %+v", virtualSession.PublicId(), msg)
|
||||
} else if flags.Flags != 0 {
|
||||
t.Errorf("Expected flags %d, got %+v", 0, msg)
|
||||
}
|
||||
|
||||
calledCtx, calledCancel = context.WithTimeout(ctx, time.Second)
|
||||
|
||||
setSessionRequestHandler(t, func(request *BackendClientSessionRequest) {
|
||||
defer calledCancel()
|
||||
if request.Action != "remove" {
|
||||
t.Errorf("Expected action remove, got %+v", request)
|
||||
} else if request.RoomId != roomId {
|
||||
t.Errorf("Expected room id %s, got %+v", roomId, request)
|
||||
} else if request.SessionId == generatedSessionId {
|
||||
t.Errorf("Expected generated session id %s, got %+v", generatedSessionId, request)
|
||||
} else if request.UserId != virtualUserId {
|
||||
t.Errorf("Expected user id %s, got %+v", virtualUserId, request)
|
||||
}
|
||||
})
|
||||
|
||||
// Messages to virtual sessions are sent to the associated client session.
|
||||
virtualRecipient := MessageClientMessageRecipient{
|
||||
Type: "session",
|
||||
SessionId: virtualSession.PublicId(),
|
||||
}
|
||||
|
||||
data := "message-to-virtual"
|
||||
client1.SendMessage(virtualRecipient, data) // nolint
|
||||
|
||||
var payload string
|
||||
var sender *MessageServerMessageSender
|
||||
var recipient *MessageClientMessageRecipient
|
||||
if err := checkReceiveClientMessageWithSenderAndRecipient(ctx, client2, "session", hello1.Hello, &payload, &sender, &recipient); err != nil {
|
||||
t.Error(err)
|
||||
} else if recipient.SessionId != virtualSessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", virtualSessionId, recipient)
|
||||
} else if payload != data {
|
||||
t.Errorf("Expected payload %s, got %s", data, payload)
|
||||
}
|
||||
|
||||
data = "control-to-virtual"
|
||||
client1.SendControl(virtualRecipient, data) // nolint
|
||||
|
||||
if err := checkReceiveClientControlWithSenderAndRecipient(ctx, client2, "session", hello1.Hello, &payload, &sender, &recipient); err != nil {
|
||||
t.Error(err)
|
||||
} else if recipient.SessionId != virtualSessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", virtualSessionId, recipient)
|
||||
} else if payload != data {
|
||||
t.Errorf("Expected payload %s, got %s", data, payload)
|
||||
}
|
||||
|
||||
if err := client2.SendInternalRemoveSession(&RemoveSessionInternalClientMessage{
|
||||
CommonSessionInternalClientMessage: CommonSessionInternalClientMessage{
|
||||
SessionId: virtualSessionId,
|
||||
RoomId: roomId,
|
||||
},
|
||||
|
||||
UserId: virtualUserId,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
<-calledCtx.Done()
|
||||
if err := calledCtx.Err(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if err := client1.checkMessageRoomLeaveSession(msg, virtualSession.PublicId()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else if err := client2.checkMessageRoomLeaveSession(msg, virtualSession.PublicId()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
22
room.go
22
room.go
|
@ -247,6 +247,9 @@ func (r *Room) processBackendRoomRequestAsyncRoom(message *AsyncRoomMessage) {
|
|||
switch message.Type {
|
||||
case "sessionjoined":
|
||||
r.notifySessionJoined(message.SessionId)
|
||||
if message.ClientType == HelloClientTypeInternal {
|
||||
r.publishUsersChangedWithInternal()
|
||||
}
|
||||
default:
|
||||
log.Printf("Unsupported async room request with type %s in %s: %+v", message.Type, r.Id(), message)
|
||||
}
|
||||
|
@ -305,8 +308,9 @@ func (r *Room) AddSession(session Session, sessionData *json.RawMessage) {
|
|||
if err := r.events.PublishBackendRoomMessage(r.id, r.backend, &AsyncMessage{
|
||||
Type: "asyncroom",
|
||||
AsyncRoom: &AsyncRoomMessage{
|
||||
Type: "sessionjoined",
|
||||
SessionId: sid,
|
||||
Type: "sessionjoined",
|
||||
SessionId: sid,
|
||||
ClientType: session.ClientType(),
|
||||
},
|
||||
}); err != nil {
|
||||
log.Printf("Error publishing joined event for session %s: %s", sid, err)
|
||||
|
@ -452,8 +456,6 @@ func (r *Room) RemoveSession(session Session) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// Still need to publish an event so sessions on other servers get notified.
|
||||
r.PublishSessionLeft(session)
|
||||
r.hub.removeRoom(r)
|
||||
r.statsRoomSessionsCurrent.Delete(prometheus.Labels{"clienttype": HelloClientTypeClient})
|
||||
r.statsRoomSessionsCurrent.Delete(prometheus.Labels{"clienttype": HelloClientTypeInternal})
|
||||
|
@ -461,6 +463,8 @@ func (r *Room) RemoveSession(session Session) bool {
|
|||
r.unsubscribeBackend()
|
||||
r.doClose()
|
||||
r.mu.Unlock()
|
||||
// Still need to publish an event so sessions on other servers get notified.
|
||||
r.PublishSessionLeft(session)
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -530,10 +534,6 @@ func (r *Room) PublishSessionJoined(session Session, sessionData *RoomSessionDat
|
|||
if err := r.publish(message); err != nil {
|
||||
log.Printf("Could not publish session joined message in room %s: %s", r.Id(), err)
|
||||
}
|
||||
|
||||
if session.ClientType() == HelloClientTypeInternal {
|
||||
r.publishUsersChangedWithInternal()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Room) PublishSessionLeft(session Session) {
|
||||
|
@ -564,6 +564,7 @@ func (r *Room) PublishSessionLeft(session Session) {
|
|||
func (r *Room) addInternalSessions(users []map[string]interface{}) []map[string]interface{} {
|
||||
now := time.Now().Unix()
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, user := range users {
|
||||
sessionid, found := user["sessionId"]
|
||||
if !found || sessionid == "" {
|
||||
|
@ -592,7 +593,6 @@ func (r *Room) addInternalSessions(users []map[string]interface{}) []map[string]
|
|||
"virtual": true,
|
||||
})
|
||||
}
|
||||
r.mu.Unlock()
|
||||
return users
|
||||
}
|
||||
|
||||
|
@ -840,6 +840,10 @@ func (r *Room) NotifySessionChanged(session Session) {
|
|||
|
||||
func (r *Room) publishUsersChangedWithInternal() {
|
||||
message := r.getParticipantsUpdateMessage(r.users)
|
||||
if len(message.Event.Update.Users) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.publish(message); err != nil {
|
||||
log.Printf("Could not publish users changed message in room %s: %s", r.Id(), err)
|
||||
}
|
||||
|
|
|
@ -80,44 +80,36 @@ func checkUnexpectedClose(err error) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func toJsonString(o interface{}) string {
|
||||
if s, err := json.Marshal(o); err != nil {
|
||||
panic(err)
|
||||
} else {
|
||||
return string(s)
|
||||
}
|
||||
}
|
||||
|
||||
func checkMessageType(message *ServerMessage, expectedType string) error {
|
||||
if message == nil {
|
||||
return ErrNoMessageReceived
|
||||
}
|
||||
|
||||
if message.Type != expectedType {
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message))
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message)
|
||||
}
|
||||
switch message.Type {
|
||||
case "hello":
|
||||
if message.Hello == nil {
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message))
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message)
|
||||
}
|
||||
case "message":
|
||||
if message.Message == nil {
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message))
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message)
|
||||
} else if message.Message.Data == nil || len(*message.Message.Data) == 0 {
|
||||
return fmt.Errorf("Received message without data")
|
||||
}
|
||||
case "room":
|
||||
if message.Room == nil {
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message))
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message)
|
||||
}
|
||||
case "event":
|
||||
if message.Event == nil {
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message))
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message)
|
||||
}
|
||||
case "transient":
|
||||
if message.TransientData == nil {
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v (%s)", expectedType, message, toJsonString(message))
|
||||
return fmt.Errorf("Expected \"%s\" message, got %+v", expectedType, message)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -137,7 +129,7 @@ func checkMessageSender(hub *Hub, sender *MessageServerMessageSender, senderType
|
|||
return nil
|
||||
}
|
||||
|
||||
func checkReceiveClientMessageWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error {
|
||||
func checkReceiveClientMessageWithSenderAndRecipient(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender, recipient **MessageClientMessageRecipient) error {
|
||||
message, err := client.RunUntilMessage(ctx)
|
||||
if err := checkUnexpectedClose(err); err != nil {
|
||||
return err
|
||||
|
@ -153,14 +145,21 @@ func checkReceiveClientMessageWithSender(ctx context.Context, client *TestClient
|
|||
if sender != nil {
|
||||
*sender = message.Message.Sender
|
||||
}
|
||||
if recipient != nil {
|
||||
*recipient = message.Message.Recipient
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkReceiveClientMessage(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}) error {
|
||||
return checkReceiveClientMessageWithSender(ctx, client, senderType, hello, payload, nil)
|
||||
func checkReceiveClientMessageWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error {
|
||||
return checkReceiveClientMessageWithSenderAndRecipient(ctx, client, senderType, hello, payload, sender, nil)
|
||||
}
|
||||
|
||||
func checkReceiveClientControlWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error {
|
||||
func checkReceiveClientMessage(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}) error {
|
||||
return checkReceiveClientMessageWithSenderAndRecipient(ctx, client, senderType, hello, payload, nil, nil)
|
||||
}
|
||||
|
||||
func checkReceiveClientControlWithSenderAndRecipient(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender, recipient **MessageClientMessageRecipient) error {
|
||||
message, err := client.RunUntilMessage(ctx)
|
||||
if err := checkUnexpectedClose(err); err != nil {
|
||||
return err
|
||||
|
@ -174,13 +173,20 @@ func checkReceiveClientControlWithSender(ctx context.Context, client *TestClient
|
|||
}
|
||||
}
|
||||
if sender != nil {
|
||||
*sender = message.Message.Sender
|
||||
*sender = message.Control.Sender
|
||||
}
|
||||
if recipient != nil {
|
||||
*recipient = message.Control.Recipient
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkReceiveClientControlWithSender(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}, sender **MessageServerMessageSender) error { // nolint
|
||||
return checkReceiveClientControlWithSenderAndRecipient(ctx, client, senderType, hello, payload, sender, nil)
|
||||
}
|
||||
|
||||
func checkReceiveClientControl(ctx context.Context, client *TestClient, senderType string, hello *HelloServerMessage, payload interface{}) error {
|
||||
return checkReceiveClientControlWithSender(ctx, client, senderType, hello, payload, nil)
|
||||
return checkReceiveClientControlWithSenderAndRecipient(ctx, client, senderType, hello, payload, nil, nil)
|
||||
}
|
||||
|
||||
func checkReceiveClientEvent(ctx context.Context, client *TestClient, eventType string, msg **EventServerMessage) error {
|
||||
|
@ -474,6 +480,42 @@ func (c *TestClient) SendControl(recipient MessageClientMessageRecipient, data i
|
|||
return c.WriteJSON(message)
|
||||
}
|
||||
|
||||
func (c *TestClient) SendInternalAddSession(msg *AddSessionInternalClientMessage) error {
|
||||
message := &ClientMessage{
|
||||
Id: "abcd",
|
||||
Type: "internal",
|
||||
Internal: &InternalClientMessage{
|
||||
Type: "addsession",
|
||||
AddSession: msg,
|
||||
},
|
||||
}
|
||||
return c.WriteJSON(message)
|
||||
}
|
||||
|
||||
func (c *TestClient) SendInternalUpdateSession(msg *UpdateSessionInternalClientMessage) error {
|
||||
message := &ClientMessage{
|
||||
Id: "abcd",
|
||||
Type: "internal",
|
||||
Internal: &InternalClientMessage{
|
||||
Type: "updatesession",
|
||||
UpdateSession: msg,
|
||||
},
|
||||
}
|
||||
return c.WriteJSON(message)
|
||||
}
|
||||
|
||||
func (c *TestClient) SendInternalRemoveSession(msg *RemoveSessionInternalClientMessage) error {
|
||||
message := &ClientMessage{
|
||||
Id: "abcd",
|
||||
Type: "internal",
|
||||
Internal: &InternalClientMessage{
|
||||
Type: "removesession",
|
||||
RemoveSession: msg,
|
||||
},
|
||||
}
|
||||
return c.WriteJSON(message)
|
||||
}
|
||||
|
||||
func (c *TestClient) SetTransientData(key string, value interface{}) error {
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
|
@ -672,10 +714,9 @@ func (c *TestClient) RunUntilJoinedAndReturn(ctx context.Context, hello ...*Hell
|
|||
if err := checkMessageType(message, "event"); err != nil {
|
||||
ignored = append(ignored, message)
|
||||
continue
|
||||
} else if message.Event.Target != "room" {
|
||||
return nil, nil, fmt.Errorf("Expected event target room, got %+v", message.Event)
|
||||
} else if message.Event.Type != "join" {
|
||||
return nil, nil, fmt.Errorf("Expected event type join, got %+v", message.Event)
|
||||
} else if message.Event.Target != "room" || message.Event.Type != "join" {
|
||||
ignored = append(ignored, message)
|
||||
continue
|
||||
}
|
||||
|
||||
for len(message.Event.Join) > 0 {
|
||||
|
|
|
@ -56,8 +56,8 @@ func GetVirtualSessionId(session *ClientSession, sessionId string) string {
|
|||
return session.PublicId() + "|" + sessionId
|
||||
}
|
||||
|
||||
func NewVirtualSession(session *ClientSession, privateId string, publicId string, data *SessionIdData, msg *AddSessionInternalClientMessage) *VirtualSession {
|
||||
return &VirtualSession{
|
||||
func NewVirtualSession(session *ClientSession, privateId string, publicId string, data *SessionIdData, msg *AddSessionInternalClientMessage) (*VirtualSession, error) {
|
||||
result := &VirtualSession{
|
||||
hub: session.hub,
|
||||
session: session,
|
||||
privateId: privateId,
|
||||
|
@ -70,6 +70,12 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string
|
|||
flags: msg.Flags,
|
||||
options: msg.Options,
|
||||
}
|
||||
|
||||
if err := session.events.RegisterSessionListener(publicId, session.Backend(), result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *VirtualSession) PrivateId() string {
|
||||
|
@ -142,6 +148,7 @@ func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *Clie
|
|||
if removed && room != nil {
|
||||
go s.notifyBackendRemoved(room, session, message)
|
||||
}
|
||||
s.session.events.UnregisterSessionListener(s.PublicId(), s.session.Backend(), s)
|
||||
}
|
||||
|
||||
func (s *VirtualSession) notifyBackendRemoved(room *Room, session *ClientSession, message *ClientMessage) {
|
||||
|
@ -177,7 +184,10 @@ func (s *VirtualSession) notifyBackendRemoved(room *Room, session *ClientSession
|
|||
return
|
||||
}
|
||||
} else {
|
||||
request := NewBackendClientSessionRequest(room.Id(), "remove", s.PublicId(), nil)
|
||||
request := NewBackendClientSessionRequest(room.Id(), "remove", s.PublicId(), &AddSessionInternalClientMessage{
|
||||
UserId: s.userId,
|
||||
User: s.userData,
|
||||
})
|
||||
var response BackendClientSessionResponse
|
||||
err := s.hub.backend.PerformJSONRequest(ctx, s.ParsedBackendUrl(), request, &response)
|
||||
if err != nil {
|
||||
|
@ -252,3 +262,36 @@ func (s *VirtualSession) Flags() uint32 {
|
|||
func (s *VirtualSession) Options() *AddSessionOptions {
|
||||
return s.options
|
||||
}
|
||||
|
||||
func (s *VirtualSession) ProcessAsyncSessionMessage(message *AsyncMessage) {
|
||||
if message.Type == "message" && message.Message != nil {
|
||||
switch message.Message.Type {
|
||||
case "message":
|
||||
if message.Message.Message != nil &&
|
||||
message.Message.Message.Recipient != nil &&
|
||||
message.Message.Message.Recipient.Type == "session" &&
|
||||
message.Message.Message.Recipient.SessionId == s.PublicId() {
|
||||
// The client should see his session id as recipient.
|
||||
message.Message.Message.Recipient = &MessageClientMessageRecipient{
|
||||
Type: "session",
|
||||
SessionId: s.SessionId(),
|
||||
UserId: s.UserId(),
|
||||
}
|
||||
s.session.ProcessAsyncSessionMessage(message)
|
||||
}
|
||||
case "control":
|
||||
if message.Message.Control != nil &&
|
||||
message.Message.Control.Recipient != nil &&
|
||||
message.Message.Control.Recipient.Type == "session" &&
|
||||
message.Message.Control.Recipient.SessionId == s.PublicId() {
|
||||
// The client should see his session id as recipient.
|
||||
message.Message.Control.Recipient = &MessageClientMessageRecipient{
|
||||
Type: "session",
|
||||
SessionId: s.SessionId(),
|
||||
UserId: s.UserId(),
|
||||
}
|
||||
s.session.ProcessAsyncSessionMessage(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue