From 707b1257309f5ad66f6b67f9c0dc41eaa6f72532 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Mon, 8 Nov 2021 12:06:59 +0100 Subject: [PATCH] Check individual audio/video permissions on change. If a client publishes audio/video and no longer has the video permission, the whole publisher will be closed. Previously this was only checking the generic "media" permission. --- api_proxy.go | 9 +-- clientsession.go | 159 ++++++++++++++++++++++++++++++++++++++++-- hub.go | 64 +++-------------- hub_test.go | 81 +++++++++++++++++---- mcu_common.go | 12 +++- mcu_janus.go | 18 +++-- mcu_proxy.go | 21 ++++-- mcu_test.go | 112 +++++++++++++++++++++++++++-- mock_data_test.go | 67 ++++++++++++++++++ proxy/proxy_server.go | 10 ++- testclient_test.go | 31 ++++++++ 11 files changed, 486 insertions(+), 98 deletions(-) diff --git a/api_proxy.go b/api_proxy.go index d9c36dc..8ffd3c9 100644 --- a/api_proxy.go +++ b/api_proxy.go @@ -172,10 +172,11 @@ type ByeProxyServerMessage struct { type CommandProxyClientMessage struct { Type string `json:"type"` - StreamType string `json:"streamType,omitempty"` - PublisherId string `json:"publisherId,omitempty"` - ClientId string `json:"clientId,omitempty"` - Bitrate int `json:"bitrate,omitempty"` + StreamType string `json:"streamType,omitempty"` + PublisherId string `json:"publisherId,omitempty"` + ClientId string `json:"clientId,omitempty"` + Bitrate int `json:"bitrate,omitempty"` + MediaTypes MediaType `json:"mediatypes,omitempty"` } func (m *CommandProxyClientMessage) CheckValid() error { diff --git a/clientsession.go b/clientsession.go index 4727666..acc92bb 100644 --- a/clientsession.go +++ b/clientsession.go @@ -24,6 +24,7 @@ package signaling import ( "context" "encoding/json" + "fmt" "log" "net/url" "strings" @@ -33,6 +34,7 @@ import ( "unsafe" "github.com/nats-io/nats.go" + "github.com/pion/sdp" ) var ( @@ -192,6 +194,14 @@ func (s *ClientSession) HasAnyPermission(permission ...Permission) bool { s.mu.Lock() defer s.mu.Unlock() + return s.hasAnyPermissionLocked(permission...) +} + +func (s *ClientSession) hasAnyPermissionLocked(permission ...Permission) bool { + if len(permission) == 0 { + return false + } + for _, p := range permission { if s.hasPermissionLocked(p) { return true @@ -671,10 +681,140 @@ func (s *ClientSession) SubscriberClosed(subscriber McuSubscriber) { } } -func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, streamType string) (McuPublisher, error) { +type SdpError struct { + message string +} + +func (e *SdpError) Error() string { + return e.message +} + +type WrappedSdpError struct { + SdpError + err error +} + +func (e *WrappedSdpError) Unwrap() error { + return e.err +} + +type PermissionError struct { + permission Permission +} + +func (e *PermissionError) Permission() Permission { + return e.permission +} + +func (e *PermissionError) Error() string { + return fmt.Sprintf("permission \"%s\" not found", e.permission) +} + +func (s *ClientSession) isSdpAllowedToSendLocked(payload map[string]interface{}) (MediaType, error) { + sdpValue, found := payload["sdp"] + if !found { + return 0, &SdpError{"payload does not contain a sdp"} + } + sdpText, ok := sdpValue.(string) + if !ok { + return 0, &SdpError{"payload does not contain a valid sdp"} + } + var sdp sdp.SessionDescription + if err := sdp.Unmarshal(sdpText); err != nil { + return 0, &WrappedSdpError{ + SdpError: SdpError{ + message: fmt.Sprintf("could not parse sdp: %s", err), + }, + err: err, + } + } + + var mediaTypes MediaType + mayPublishMedia := s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_MEDIA) + for _, md := range sdp.MediaDescriptions { + switch md.MediaName.Media { + case "audio": + if !mayPublishMedia && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_AUDIO) { + return 0, &PermissionError{PERMISSION_MAY_PUBLISH_AUDIO} + } + + mediaTypes |= MediaTypeAudio + case "video": + if !mayPublishMedia && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_VIDEO) { + return 0, &PermissionError{PERMISSION_MAY_PUBLISH_VIDEO} + } + + mediaTypes |= MediaTypeVideo + } + } + + return mediaTypes, nil +} + +func (s *ClientSession) IsAllowedToSend(data *MessageClientMessageData) error { s.mu.Lock() defer s.mu.Unlock() + if data != nil && data.RoomType == "screen" { + if s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) { + return nil + } + return &PermissionError{PERMISSION_MAY_PUBLISH_SCREEN} + } else if s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_MEDIA) { + // Client is allowed to publish any media (audio / video). + return nil + } else if data != nil && data.Type == "offer" { + // Parse SDP to check what user is trying to publish and check permissions accordingly. + if _, err := s.isSdpAllowedToSendLocked(data.Payload); err != nil { + return err + } + + return nil + } else { + // Candidate or unknown event, check if client is allowed to publish any media. + if s.hasAnyPermissionLocked(PERMISSION_MAY_PUBLISH_AUDIO, PERMISSION_MAY_PUBLISH_VIDEO) { + return nil + } + + return fmt.Errorf("permission check failed") + } +} + +func (s *ClientSession) CheckOfferType(streamType string, data *MessageClientMessageData) (MediaType, error) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.checkOfferTypeLocked(streamType, data) +} + +func (s *ClientSession) checkOfferTypeLocked(streamType string, data *MessageClientMessageData) (MediaType, error) { + if streamType == streamTypeScreen { + if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) { + return 0, &PermissionError{PERMISSION_MAY_PUBLISH_SCREEN} + } + + return MediaTypeScreen, nil + } else if data != nil && data.Type == "offer" { + mediaTypes, err := s.isSdpAllowedToSendLocked(data.Payload) + if err != nil { + return 0, err + } + + return mediaTypes, nil + } + + return 0, nil +} + +func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, streamType string, data *MessageClientMessageData) (McuPublisher, error) { + s.mu.Lock() + defer s.mu.Unlock() + + mediaTypes, err := s.checkOfferTypeLocked(streamType, data) + if err != nil { + return nil, err + } + publisher, found := s.publishers[streamType] if !found { client := s.getClientUnlocked() @@ -689,7 +829,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea } } var err error - publisher, err = mcu.NewPublisher(ctx, s, s.PublicId(), streamType, bitrate, client) + publisher, err = mcu.NewPublisher(ctx, s, s.PublicId(), streamType, bitrate, mediaTypes, client) s.mu.Lock() if err != nil { return nil, err @@ -777,11 +917,15 @@ func (s *ClientSession) processClientMessage(msg *nats.Msg) { if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_MEDIA) { if publisher, found := s.publishers[streamTypeVideo]; found { - delete(s.publishers, streamTypeVideo) - log.Printf("Session %s is no longer allowed to publish media, closing publisher %s", s.PublicId(), publisher.Id()) - go func() { - publisher.Close(context.Background()) - }() + if (publisher.HasMedia(MediaTypeAudio) && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_AUDIO)) || + (publisher.HasMedia(MediaTypeVideo) && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_VIDEO)) { + delete(s.publishers, streamTypeVideo) + log.Printf("Session %s is no longer allowed to publish media, closing publisher %s", s.PublicId(), publisher.Id()) + go func() { + publisher.Close(context.Background()) + }() + return + } } } if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) { @@ -791,6 +935,7 @@ func (s *ClientSession) processClientMessage(msg *nats.Msg) { go func() { publisher.Close(context.Background()) }() + return } } }() diff --git a/hub.go b/hub.go index e2965a0..cbb27a6 100644 --- a/hub.go +++ b/hub.go @@ -42,7 +42,6 @@ import ( "github.com/gorilla/mux" "github.com/gorilla/securecookie" "github.com/gorilla/websocket" - "github.com/pion/sdp" ) var ( @@ -1376,7 +1375,7 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { if recipient != nil { // The recipient is connected to this instance, no need to go through NATS. if clientData != nil && clientData.Type == "sendoffer" { - if err := isAllowedToSend(session, clientData); err != nil { + if err := session.IsAllowedToSend(clientData); err != nil { log.Printf("Session %s is not allowed to send offer for %s, ignoring (%s)", session.PublicId(), clientData.RoomType, err) sendNotAllowed(session, message, "Not allowed to send offer") return @@ -1653,53 +1652,6 @@ func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { } } -func isAllowedToSend(session *ClientSession, data *MessageClientMessageData) error { - if data.RoomType == "screen" { - if session.HasPermission(PERMISSION_MAY_PUBLISH_SCREEN) { - return nil - } - return fmt.Errorf("permission \"%s\" not found", PERMISSION_MAY_PUBLISH_SCREEN) - } else if session.HasPermission(PERMISSION_MAY_PUBLISH_MEDIA) { - // Client is allowed to publish any media (audio / video). - return nil - } else if data != nil && data.Type == "offer" { - // Parse SDP to check what user is trying to publish and check permissions accordingly. - sdpValue, found := data.Payload["sdp"] - if !found { - return fmt.Errorf("offer does not contain a sdp") - } - sdpText, ok := sdpValue.(string) - if !ok { - return fmt.Errorf("offer does not contain a valid sdp") - } - var s sdp.SessionDescription - if err := s.Unmarshal(sdpText); err != nil { - return fmt.Errorf("could not parse sdp: %w", err) - } - for _, md := range s.MediaDescriptions { - switch md.MediaName.Media { - case "audio": - if !session.HasPermission(PERMISSION_MAY_PUBLISH_AUDIO) { - return fmt.Errorf("permission \"%s\" not found", PERMISSION_MAY_PUBLISH_AUDIO) - } - case "video": - if !session.HasPermission(PERMISSION_MAY_PUBLISH_VIDEO) { - return fmt.Errorf("permission \"%s\" not found", PERMISSION_MAY_PUBLISH_VIDEO) - } - } - } - return nil - } else { - // Candidate or unknown event, check if client is allowed to publish any media. - if session.HasAnyPermission(PERMISSION_MAY_PUBLISH_AUDIO, PERMISSION_MAY_PUBLISH_VIDEO) { - return nil - } - - return fmt.Errorf("permission check failed") - } - -} - func sendNotAllowed(session *ClientSession, message *ClientMessage, reason string) { response := message.NewErrorServerMessage(NewError("not_allowed", reason)) session.SendMessage(response) @@ -1772,14 +1724,18 @@ func (h *Hub) processMcuMessage(senderSession *ClientSession, session *ClientSes clientType = "subscriber" mc, err = session.GetOrCreateSubscriber(ctx, h.mcu, message.Recipient.SessionId, data.RoomType) case "offer": - if err := isAllowedToSend(session, data); err != nil { + clientType = "publisher" + mc, err = session.GetOrCreatePublisher(ctx, h.mcu, data.RoomType, data) + if err, ok := err.(*PermissionError); ok { log.Printf("Session %s is not allowed to offer %s, ignoring (%s)", session.PublicId(), data.RoomType, err) sendNotAllowed(senderSession, client_message, "Not allowed to publish.") return } - - clientType = "publisher" - mc, err = session.GetOrCreatePublisher(ctx, h.mcu, data.RoomType) + if err, ok := err.(*SdpError); ok { + log.Printf("Session %s sent unsupported offer %s, ignoring (%s)", session.PublicId(), data.RoomType, err) + sendNotAllowed(senderSession, client_message, "Not allowed to publish.") + return + } case "selectStream": if session.PublicId() == message.Recipient.SessionId { log.Printf("Not selecting substream for own %s stream in session %s", data.RoomType, session.PublicId()) @@ -1790,7 +1746,7 @@ func (h *Hub) processMcuMessage(senderSession *ClientSession, session *ClientSes mc = session.GetSubscriber(message.Recipient.SessionId, data.RoomType) default: if session.PublicId() == message.Recipient.SessionId { - if err := isAllowedToSend(session, data); err != nil { + if err := session.IsAllowedToSend(data); err != nil { log.Printf("Session %s is not allowed to send candidate for %s, ignoring (%s)", session.PublicId(), data.RoomType, err) sendNotAllowed(senderSession, client_message, "Not allowed to send candidate.") return diff --git a/hub_test.go b/hub_test.go index feb16a1..90143d7 100644 --- a/hub_test.go +++ b/hub_test.go @@ -114,6 +114,13 @@ func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Serve if err != nil { t.Fatal(err) } + b, err := NewBackendServer(config, h, "no-version") + if err != nil { + t.Fatal(err) + } + if err := b.Start(r); err != nil { + t.Fatal(err) + } go h.Run() @@ -2536,14 +2543,8 @@ func TestClientSendOfferPermissionsAudioOnly(t *testing.T) { t.Fatal(err) } - // The test MCU doesn't support clients yet, so an error will be returned - // to the client trying to send the offer. - if msg, err := client1.RunUntilMessage(ctx); err != nil { + if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioOnly); err != nil { t.Fatal(err) - } else { - if err := checkMessageError(msg, "client_not_found"); err != nil { - t.Fatal(err) - } } } @@ -2611,14 +2612,68 @@ func TestClientSendOfferPermissionsAudioVideo(t *testing.T) { t.Fatal(err) } - // The test MCU doesn't support clients yet, so an error will be returned - // to the client trying to send the offer. - if msg, err := client1.RunUntilMessage(ctx); err != nil { + if err := client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo); err != nil { t.Fatal(err) - } else { - if err := checkMessageError(msg, "client_not_found"); err != nil { - t.Fatal(err) + } + + // Client is no longer allowed to send video, this will stop the publisher. + msg := &BackendServerRoomRequest{ + Type: "participants", + Participants: &BackendRoomParticipantsRequest{ + Changed: []map[string]interface{}{ + { + "sessionId": roomId + "-" + hello1.Hello.SessionId, + "permissions": []Permission{PERMISSION_MAY_PUBLISH_AUDIO}, + }, + }, + Users: []map[string]interface{}{ + { + "sessionId": roomId + "-" + hello1.Hello.SessionId, + "permissions": []Permission{PERMISSION_MAY_PUBLISH_AUDIO}, + }, + }, + }, + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatal(err) + } + res, err := performBackendRequest(server.URL+"/api/v1/room/"+roomId, data) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + if res.StatusCode != 200 { + t.Errorf("Expected successful request, got %s: %s", res.Status, string(body)) + } + + ctx2, cancel2 := context.WithTimeout(ctx, time.Second) + defer cancel2() + + pubs := mcu.GetPublishers() + if len(pubs) != 1 { + t.Fatalf("expected one publisher, got %+v", pubs) + } + +loop: + for { + if err := ctx2.Err(); err != nil { + t.Errorf("publisher was not closed: %s", err) } + + for _, pub := range pubs { + if pub.isClosed() { + break loop + } + } + + // Give some time to async processing. + time.Sleep(time.Millisecond) } } diff --git a/mcu_common.go b/mcu_common.go index 6ac89ed..a5bf7a9 100644 --- a/mcu_common.go +++ b/mcu_common.go @@ -39,6 +39,14 @@ var ( ErrNotConnected = fmt.Errorf("not connected") ) +type MediaType int + +const ( + MediaTypeAudio MediaType = 1 << 0 + MediaTypeVideo MediaType = 1 << 1 + MediaTypeScreen MediaType = 1 << 2 +) + type McuListener interface { PublicId() string @@ -63,7 +71,7 @@ type Mcu interface { GetStats() interface{} - NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, bitrate int, initiator McuInitiator) (McuPublisher, error) + NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) } @@ -78,6 +86,8 @@ type McuClient interface { type McuPublisher interface { McuClient + + HasMedia(MediaType) bool } type McuSubscriber interface { diff --git a/mcu_janus.go b/mcu_janus.go index 81ae35f..3e147ec 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -682,9 +682,10 @@ func (c *publisherStatsCounter) RemoveSubscriber(id string) { type mcuJanusPublisher struct { mcuJanusClient - id string - bitrate int - stats publisherStatsCounter + id string + bitrate int + mediaTypes MediaType + stats publisherStatsCounter } func (m *mcuJanus) SubscriberConnected(id string, publisher string, streamType string) { @@ -781,7 +782,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st return handle, response.Session, roomId, nil } -func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, bitrate int, initiator McuInitiator) (McuPublisher, error) { +func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { if _, found := streamTypeUserIds[streamType]; !found { return nil, fmt.Errorf("Unsupported stream type %s", streamType) } @@ -806,8 +807,9 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id st closeChan: make(chan bool, 1), deferred: make(chan func(), 64), }, - id: id, - bitrate: bitrate, + id: id, + bitrate: bitrate, + mediaTypes: mediaTypes, } client.mcuJanusClient.handleEvent = client.handleEvent client.mcuJanusClient.handleHangup = client.handleHangup @@ -878,6 +880,10 @@ func (p *mcuJanusPublisher) handleMedia(event *janus.MediaMsg) { p.stats.EnableStream(mediaType, event.Receiving) } +func (p *mcuJanusPublisher) HasMedia(mt MediaType) bool { + return (p.mediaTypes & mt) == mt +} + func (p *mcuJanusPublisher) NotifyReconnected() { ctx := context.TODO() handle, session, roomId, err := p.mcu.getOrCreatePublisherHandle(ctx, p.id, p.streamType, p.bitrate) diff --git a/mcu_proxy.go b/mcu_proxy.go index 95dd074..63edca6 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -117,10 +117,11 @@ func (c *mcuProxyPubSubCommon) doProcessPayload(client McuClient, msg *PayloadPr type mcuProxyPublisher struct { mcuProxyPubSubCommon - id string + id string + mediaTypes MediaType } -func newMcuProxyPublisher(id string, streamType string, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher { +func newMcuProxyPublisher(id string, streamType string, mediaTypes MediaType, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher { return &mcuProxyPublisher{ mcuProxyPubSubCommon: mcuProxyPubSubCommon{ streamType: streamType, @@ -128,10 +129,15 @@ func newMcuProxyPublisher(id string, streamType string, proxyId string, conn *mc conn: conn, listener: listener, }, - id: id, + id: id, + mediaTypes: mediaTypes, } } +func (p *mcuProxyPublisher) HasMedia(mt MediaType) bool { + return (p.mediaTypes & mt) == mt +} + func (p *mcuProxyPublisher) NotifyClosed() { p.listener.PublisherClosed(p) p.conn.removePublisher(p) @@ -920,13 +926,14 @@ func (c *mcuProxyConnection) performSyncRequest(ctx context.Context, msg *ProxyC } } -func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListener, id string, streamType string, bitrate int) (McuPublisher, error) { +func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListener, id string, streamType string, bitrate int, mediaTypes MediaType) (McuPublisher, error) { msg := &ProxyClientMessage{ Type: "command", Command: &CommandProxyClientMessage{ Type: "create-publisher", StreamType: streamType, Bitrate: bitrate, + MediaTypes: mediaTypes, }, } @@ -938,7 +945,7 @@ func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListe proxyId := response.Command.Id log.Printf("Created %s publisher %s on %s for %s", streamType, proxyId, c.url, id) - publisher := newMcuProxyPublisher(id, streamType, proxyId, c, listener) + publisher := newMcuProxyPublisher(id, streamType, mediaTypes, proxyId, c, listener) c.publishersLock.Lock() c.publishers[proxyId] = publisher c.publisherIds[id+"|"+streamType] = proxyId @@ -1679,7 +1686,7 @@ func (m *mcuProxy) removeWaiter(id uint64) { delete(m.publisherWaiters, id) } -func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, bitrate int, initiator McuInitiator) (McuPublisher, error) { +func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { connections := m.getSortedConnections(initiator) for _, conn := range connections { if conn.IsShutdownScheduled() { @@ -1700,7 +1707,7 @@ func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id st } else { bitrate = min(bitrate, maxBitrate) } - publisher, err := conn.newPublisher(subctx, listener, id, streamType, bitrate) + publisher, err := conn.newPublisher(subctx, listener, id, streamType, bitrate, mediaTypes) if err != nil { log.Printf("Could not create %s publisher for %s on %s: %s", streamType, id, conn.url, err) continue diff --git a/mcu_test.go b/mcu_test.go index 6e59a81..a08c9d0 100644 --- a/mcu_test.go +++ b/mcu_test.go @@ -24,15 +24,22 @@ package signaling import ( "context" "fmt" + "log" + "sync" + "sync/atomic" "github.com/dlintw/goconf" ) type TestMCU struct { + mu sync.Mutex + publishers map[string]*TestMCUPublisher } -func NewTestMCU() (Mcu, error) { - return &TestMCU{}, nil +func NewTestMCU() (*TestMCU, error) { + return &TestMCU{ + publishers: make(map[string]*TestMCUPublisher), + }, nil } func (m *TestMCU) Start() error { @@ -55,10 +62,107 @@ func (m *TestMCU) GetStats() interface{} { return nil } -func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, bitrate int, initiator McuInitiator) (McuPublisher, error) { - return nil, fmt.Errorf("Not implemented") +func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { + pub := &TestMCUPublisher{ + TestMCUClient: TestMCUClient{ + id: id, + streamType: streamType, + }, + + mediaTypes: mediaTypes, + } + + m.mu.Lock() + defer m.mu.Unlock() + + m.publishers[id] = pub + return pub, nil +} + +func (m *TestMCU) GetPublishers() map[string]*TestMCUPublisher { + m.mu.Lock() + defer m.mu.Unlock() + + result := make(map[string]*TestMCUPublisher, len(m.publishers)) + for id, pub := range m.publishers { + result[id] = pub + } + return result +} + +func (m *TestMCU) GetPublisher(id string) *TestMCUPublisher { + m.mu.Lock() + defer m.mu.Unlock() + + return m.publishers[id] } func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { return nil, fmt.Errorf("Not implemented") } + +type TestMCUClient struct { + closed int32 + + id string + streamType string +} + +func (c *TestMCUClient) Id() string { + return c.id +} + +func (c *TestMCUClient) StreamType() string { + return c.streamType +} + +func (c *TestMCUClient) Close(ctx context.Context) { + log.Printf("Close MCU client %s", c.id) + atomic.StoreInt32(&c.closed, 1) +} + +func (c *TestMCUClient) isClosed() bool { + return atomic.LoadInt32(&c.closed) != 0 +} + +type TestMCUPublisher struct { + TestMCUClient + + mediaTypes MediaType +} + +func (p *TestMCUPublisher) HasMedia(mt MediaType) bool { + return (p.mediaTypes & mt) == mt +} + +func (p *TestMCUPublisher) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { + go func() { + if p.isClosed() { + callback(fmt.Errorf("Already closed"), nil) + return + } + + switch data.Type { + case "offer": + sdp := data.Payload["sdp"] + if sdp, ok := sdp.(string); ok { + if sdp == MockSdpOfferAudioOnly { + callback(nil, map[string]interface{}{ + "type": "answer", + "sdp": MockSdpAnswerAudioOnly, + }) + return + } else if sdp == MockSdpOfferAudioAndVideo { + callback(nil, map[string]interface{}{ + "type": "answer", + "sdp": MockSdpAnswerAudioAndVideo, + }) + return + } + } + callback(fmt.Errorf("Offer payload %+v is not implemented", data.Payload), nil) + default: + callback(fmt.Errorf("Message type %s is not implemented", data.Type), nil) + } + }() +} diff --git a/mock_data_test.go b/mock_data_test.go index ccf35fa..ed0063b 100644 --- a/mock_data_test.go +++ b/mock_data_test.go @@ -32,6 +32,34 @@ a=candidate:1 1 UDP 1685987071 192.168.0.1 54609 typ srflx raddr 192.0.2.4 rport a=candidate:0 2 UDP 2122194687 192.0.2.4 61667 typ host a=candidate:1 2 UDP 1685987071 192.168.0.1 60065 typ srflx raddr 192.0.2.4 rport 61667 a=end-of-candidates +` + MockSdpAnswerAudioOnly = `v=0 +o=- 16833 0 IN IP4 0.0.0.0 +s=- +t=0 0 +a=group:BUNDLE audio +a=ice-options:trickle +m=audio 49203 UDP/TLS/RTP/SAVPF 109 0 8 +c=IN IP4 192.168.0.1 +a=mid:audio +a=msid:ma ta +a=sendrecv +a=rtpmap:109 opus/48000/2 +a=rtpmap:0 PCMU/8000 +a=rtpmap:8 PCMA/8000 +a=maxptime:120 +a=ice-ufrag:05067423 +a=ice-pwd:1747d1ee3474a28a397a4c3f3af08a068 +a=fingerprint:sha-256 6B:8B:F0:65:5F:78:E2:51:3B:AC:6F:F3:3F:46:1B:35:DC:B8:5F:64:1A:24:C2:43:F0:A1:58:D0:A1:2C:19:08 +a=setup:active +a=tls-id:1 +a=rtcp-mux +a=rtcp-rsize +a=extmap:1 urn:ietf:params:rtp-hdrext:ssrc-audio-level +a=extmap:2 urn:ietf:params:rtp-hdrext:sdes:mid +a=candidate:0 1 UDP 2122194687 198.51.100.7 51556 typ host +a=candidate:1 1 UDP 1685987071 192.168.0.1 49203 typ srflx raddr 198.51.100.7 rport 51556 +a=end-of-candidates ` // See https://tools.ietf.org/id/draft-ietf-rtcweb-sdp-08.html#rfc.section.5.2.2.1 @@ -80,5 +108,44 @@ a=rtcp-fb:120 nack a=rtcp-fb:120 nack pli a=rtcp-fb:120 ccm fir a=extmap:2 urn:ietf:params:rtp-hdrext:sdes:mid +` + MockSdpAnswerAudioAndVideo = `v=0 +o=- 16833 0 IN IP4 0.0.0.0 +s=- +t=0 0 +a=group:BUNDLE audio +a=ice-options:trickle +m=audio 49203 UDP/TLS/RTP/SAVPF 109 0 8 +c=IN IP4 192.168.0.1 +a=mid:audio +a=msid:ma ta +a=sendrecv +a=rtpmap:109 opus/48000/2 +a=rtpmap:0 PCMU/8000 +a=rtpmap:8 PCMA/8000 +a=maxptime:120 +a=ice-ufrag:05067423 +a=ice-pwd:1747d1ee3474a28a397a4c3f3af08a068 +a=fingerprint:sha-256 6B:8B:F0:65:5F:78:E2:51:3B:AC:6F:F3:3F:46:1B:35:DC:B8:5F:64:1A:24:C2:43:F0:A1:58:D0:A1:2C:19:08 +a=setup:active +a=tls-id:1 +a=rtcp-mux +a=rtcp-rsize +a=extmap:1 urn:ietf:params:rtp-hdrext:ssrc-audio-level +a=extmap:2 urn:ietf:params:rtp-hdrext:sdes:mid +a=candidate:0 1 UDP 2122194687 198.51.100.7 51556 typ host +a=candidate:1 1 UDP 1685987071 192.168.0.1 49203 typ srflx raddr 198.51.100.7 rport 51556 +a=end-of-candidates +m=video 49203 UDP/TLS/RTP/SAVPF 99 +c=IN IP4 192.168.0.1 +a=mid:video +a=msid:ma tb +a=sendrecv +a=rtpmap:99 H264/90000 +a=fmtp:99 profile-level-id=4d0028;packetization-mode=1 +a=rtcp-fb:99 nack +a=rtcp-fb:99 nack pli +a=rtcp-fb:99 ccm fir +a=extmap:2 urn:ietf:params:rtp-hdrext:sdes:mid ` ) diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 0b26b4d..836293c 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -635,7 +635,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } id := uuid.New().String() - publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.StreamType, cmd.Bitrate, &emptyInitiator{}) + publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.StreamType, cmd.Bitrate, cmd.MediaTypes, &emptyInitiator{}) if err == context.DeadlineExceeded { log.Printf("Timeout while creating %s publisher %s for %s", cmd.StreamType, id, session.PublicId()) session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingPublisher)) @@ -695,7 +695,13 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s return } - if session.DeletePublisher(client) == "" { + publisher, ok := client.(signaling.McuPublisher) + if !ok { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + if session.DeletePublisher(publisher) == "" { session.sendMessage(message.NewErrorServerMessage(UnknownClient)) return } diff --git a/testclient_test.go b/testclient_test.go index 10818e9..4edba45 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -704,3 +704,34 @@ func checkMessageError(message *ServerMessage, msgid string) error { return nil } + +func (c *TestClient) RunUntilAnswer(ctx context.Context, answer string) error { + message, err := c.RunUntilMessage(ctx) + if err != nil { + return err + } + if err := checkUnexpectedClose(err); err != nil { + return err + } else if err := checkMessageType(message, "message"); err != nil { + return err + } + + var data map[string]interface{} + if err := json.Unmarshal(*message.Message.Data, &data); err != nil { + return err + } + + if data["type"].(string) != "answer" { + return fmt.Errorf("expected data type answer, got %+v", data) + } + + payload := data["payload"].(map[string]interface{}) + if payload["type"].(string) != "answer" { + return fmt.Errorf("expected payload type answer, got %+v", payload) + } + if payload["sdp"].(string) != answer { + return fmt.Errorf("expected payload answer %s, got %+v", answer, payload) + } + + return nil +}