From 66e502dc9b37a741e23849c688678a035e3ae81e Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Fri, 8 Jul 2022 16:53:45 +0200 Subject: [PATCH] Fix handling of "unshareScreen" messages and add test. Also update tests for "requestoffer" / "sendoffer". --- hub.go | 46 ++++++------ hub_test.go | 183 ++++++++++++++++++++++++++++++++++++++------- mcu_test.go | 72 ++++++++++++++++-- testclient_test.go | 31 ++++++++ 4 files changed, 274 insertions(+), 58 deletions(-) diff --git a/hub.go b/hub.go index 893f673..d82a82d 100644 --- a/hub.go +++ b/hub.go @@ -1243,6 +1243,29 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { case "candidate": h.processMcuMessage(session, message, msg, clientData) return + case "unshareScreen": + if msg.Recipient.SessionId == session.PublicId() { + // User is stopping to share his screen. Firefox doesn't properly clean + // up the peer connections in all cases, so make sure to stop publishing + // in the MCU. + go func(c *Client) { + time.Sleep(cleanupScreenPublisherDelay) + session := c.GetSession() + if session == nil { + return + } + + publisher := session.GetPublisher(streamTypeScreen) + if publisher == nil { + return + } + + log.Printf("Closing screen publisher for %s", session.PublicId()) + ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout) + defer cancel() + publisher.Close(ctx) + }(client) + } } } } @@ -1313,29 +1336,6 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { return } - if clientData != nil && clientData.Type == "unshareScreen" { - // User is stopping to share his screen. Firefox doesn't properly clean - // up the peer connections in all cases, so make sure to stop publishing - // in the MCU. - go func(c *Client) { - time.Sleep(cleanupScreenPublisherDelay) - session := c.GetSession() - if session == nil { - return - } - - publisher := session.GetPublisher(streamTypeScreen) - if publisher == nil { - return - } - - log.Printf("Closing screen publisher for %s", session.PublicId()) - ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout) - defer cancel() - publisher.Close(ctx) - }(client) - } - response := &ServerMessage{ Type: "message", Message: &MessageServerMessage{ diff --git a/hub_test.go b/hub_test.go index ad25262..f9f41f0 100644 --- a/hub_test.go +++ b/hub_test.go @@ -3148,6 +3148,24 @@ func TestClientSendOfferPermissions(t *testing.T) { } } + if err := client1.SendMessage(MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, MessageClientMessageData{ + Type: "offer", + Sid: "12345", + RoomType: "screen", + Payload: map[string]interface{}{ + "sdp": MockSdpOfferAudioAndVideo, + }, + }); err != nil { + t.Fatal(err) + } + + if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil { + t.Fatal(err) + } + // Client 1 may send an offer. if err := client1.SendMessage(MessageClientMessageRecipient{ Type: "session", @@ -3160,25 +3178,19 @@ func TestClientSendOfferPermissions(t *testing.T) { t.Fatal(err) } - // The test MCU doesn't support clients yet, so an error will be returned - // to the client trying to send the offer. - if msg, err := client1.RunUntilMessage(ctx); err != nil { - t.Fatal(err) - } else { - if err := checkMessageError(msg, "client_not_found"); err != nil { - t.Fatal(err) - } - } - - ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) + // The sender won't get a reply... + ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel2() - if msg, err := client2.RunUntilMessage(ctx2); err != nil { - if err != context.DeadlineExceeded { - t.Fatal(err) - } - } else { - t.Errorf("Expected no payload, got %+v", msg) + if message, err := client1.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { + t.Error(err) + } else if message != nil { + t.Errorf("Expected no message, got %+v", message) + } + + // ...but the other peer will get an offer. + if err := client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo); err != nil { + t.Fatal(err) } } @@ -3321,7 +3333,6 @@ func TestClientSendOfferPermissionsAudioVideo(t *testing.T) { // Client is allowed to send audio and video. session1.SetPermissions([]Permission{PERMISSION_MAY_PUBLISH_AUDIO, PERMISSION_MAY_PUBLISH_VIDEO}) - // Client may send an offer (audio and video). if err := client1.SendMessage(MessageClientMessageRecipient{ Type: "session", SessionId: hello1.Hello.SessionId, @@ -3600,6 +3611,24 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { t.Error(err) } + if err := client1.SendMessage(MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, MessageClientMessageData{ + Type: "offer", + Sid: "54321", + RoomType: "screen", + Payload: map[string]interface{}{ + "sdp": MockSdpOfferAudioAndVideo, + }, + }); err != nil { + t.Fatal(err) + } + + if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil { + t.Fatal(err) + } + // Client 2 may not request an offer (he is not in the room yet). if err := client2.SendMessage(MessageClientMessageRecipient{ Type: "session", @@ -3724,13 +3753,8 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { t.Fatal(err) } - if msg, err := client2.RunUntilMessage(ctx); err != nil { + if err := client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo); err != nil { t.Fatal(err) - } else { - // We check for "client_not_found" as the testing MCU doesn't support publishing/subscribing. - if err := checkMessageError(msg, "client_not_found"); err != nil { - t.Fatal(err) - } } }) } @@ -4017,13 +4041,114 @@ func TestClientSendOffer(t *testing.T) { t.Fatal(err) } - if msg, err := client1.RunUntilMessage(ctx); err != nil { + // The sender won't get a reply... + ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel2() + + if message, err := client1.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded { + t.Error(err) + } else if message != nil { + t.Errorf("Expected no message, got %+v", message) + } + + // ...but the other peer will get an offer. + if err := client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo); err != nil { t.Fatal(err) - } else { - if err := checkMessageError(msg, "client_not_found"); err != nil { - t.Fatal(err) - } } }) } } + +func TestClientUnshareScreen(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + mcu, err := NewTestMCU() + if err != nil { + t.Fatal(err) + } else if err := mcu.Start(); err != nil { + t.Fatal(err) + } + defer mcu.Stop() + + hub.SetMcu(mcu) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewTestClient(t, server, hub) + defer client1.CloseWithBye() + + if err := client1.SendHello(testDefaultUserId + "1"); err != nil { + t.Fatal(err) + } + + hello1, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + + // Join room by id. + roomId := "test-room" + if room, err := client1.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } else if room.Room.RoomId != roomId { + t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + } + + if err := client1.RunUntilJoined(ctx, hello1.Hello); err != nil { + t.Error(err) + } + + session1 := hub.GetSessionByPublicId(hello1.Hello.SessionId).(*ClientSession) + if session1 == nil { + t.Fatalf("Session %s does not exist", hello1.Hello.SessionId) + } + + if err := client1.SendMessage(MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, MessageClientMessageData{ + Type: "offer", + Sid: "54321", + RoomType: "screen", + Payload: map[string]interface{}{ + "sdp": MockSdpOfferAudioOnly, + }, + }); err != nil { + t.Fatal(err) + } + + if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioOnly); err != nil { + t.Fatal(err) + } + + publisher := mcu.GetPublisher(hello1.Hello.SessionId) + if publisher == nil { + t.Fatalf("No publisher for %s found", hello1.Hello.SessionId) + } else if publisher.isClosed() { + t.Fatalf("Publisher %s should not be closed", hello1.Hello.SessionId) + } + + old := cleanupScreenPublisherDelay + cleanupScreenPublisherDelay = time.Millisecond + defer func() { + cleanupScreenPublisherDelay = old + }() + + if err := client1.SendMessage(MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, MessageClientMessageData{ + Type: "unshareScreen", + Sid: "54321", + RoomType: "screen", + }); err != nil { + t.Fatal(err) + } + + time.Sleep(10 * time.Millisecond) + + if !publisher.isClosed() { + t.Fatalf("Publisher %s should be closed", hello1.Hello.SessionId) + } +} diff --git a/mcu_test.go b/mcu_test.go index 4172ed8..a95c880 100644 --- a/mcu_test.go +++ b/mcu_test.go @@ -37,13 +37,15 @@ const ( ) type TestMCU struct { - mu sync.Mutex - publishers map[string]*TestMCUPublisher + mu sync.Mutex + publishers map[string]*TestMCUPublisher + subscribers map[string]*TestMCUSubscriber } func NewTestMCU() (*TestMCU, error) { return &TestMCU{ - publishers: make(map[string]*TestMCUPublisher), + publishers: make(map[string]*TestMCUPublisher), + subscribers: make(map[string]*TestMCUSubscriber), }, nil } @@ -116,7 +118,24 @@ func (m *TestMCU) GetPublisher(id string) *TestMCUPublisher { } func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { - return nil, fmt.Errorf("Not implemented") + m.mu.Lock() + defer m.mu.Unlock() + + pub := m.publishers[publisher] + if pub == nil { + return nil, fmt.Errorf("Waiting for publisher not implemented yet") + } + + id := newRandomString(8) + sub := &TestMCUSubscriber{ + TestMCUClient: TestMCUClient{ + id: id, + streamType: streamType, + }, + + publisher: pub, + } + return sub, nil } type TestMCUClient struct { @@ -140,8 +159,9 @@ func (c *TestMCUClient) StreamType() string { } func (c *TestMCUClient) Close(ctx context.Context) { - log.Printf("Close MCU client %s", c.id) - atomic.StoreInt32(&c.closed, 1) + if atomic.CompareAndSwapInt32(&c.closed, 0, 1) { + log.Printf("Close MCU client %s", c.id) + } } func (c *TestMCUClient) isClosed() bool { @@ -153,6 +173,8 @@ type TestMCUPublisher struct { mediaTypes MediaType bitrate int + + sdp string } func (p *TestMCUPublisher) HasMedia(mt MediaType) bool { @@ -174,6 +196,7 @@ func (p *TestMCUPublisher) SendMessage(ctx context.Context, message *MessageClie case "offer": sdp := data.Payload["sdp"] if sdp, ok := sdp.(string); ok { + p.sdp = sdp if sdp == MockSdpOfferAudioOnly { callback(nil, map[string]interface{}{ "type": "answer", @@ -194,3 +217,40 @@ func (p *TestMCUPublisher) SendMessage(ctx context.Context, message *MessageClie } }() } + +type TestMCUSubscriber struct { + TestMCUClient + + publisher *TestMCUPublisher +} + +func (s *TestMCUSubscriber) Publisher() string { + return s.publisher.id +} + +func (s *TestMCUSubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { + go func() { + if s.isClosed() { + callback(fmt.Errorf("Already closed"), nil) + return + } + + switch data.Type { + case "requestoffer": + fallthrough + case "sendoffer": + sdp := s.publisher.sdp + if sdp == "" { + callback(fmt.Errorf("Publisher not sending (no SDP)"), nil) + return + } + + callback(nil, map[string]interface{}{ + "type": "offer", + "sdp": sdp, + }) + default: + callback(fmt.Errorf("Message type %s is not implemented", data.Type), nil) + } + }() +} diff --git a/testclient_test.go b/testclient_test.go index c54eb6d..dbe3101 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -837,6 +837,37 @@ func checkMessageError(message *ServerMessage, msgid string) error { return nil } +func (c *TestClient) RunUntilOffer(ctx context.Context, offer string) error { + message, err := c.RunUntilMessage(ctx) + if err != nil { + return err + } + if err := checkUnexpectedClose(err); err != nil { + return err + } else if err := checkMessageType(message, "message"); err != nil { + return err + } + + var data map[string]interface{} + if err := json.Unmarshal(*message.Message.Data, &data); err != nil { + return err + } + + if data["type"].(string) != "offer" { + return fmt.Errorf("expected data type offer, got %+v", data) + } + + payload := data["payload"].(map[string]interface{}) + if payload["type"].(string) != "offer" { + return fmt.Errorf("expected payload type offer, got %+v", payload) + } + if payload["sdp"].(string) != offer { + return fmt.Errorf("expected payload answer %s, got %+v", offer, payload) + } + + return nil +} + func (c *TestClient) RunUntilAnswer(ctx context.Context, answer string) error { message, err := c.RunUntilMessage(ctx) if err != nil {