From 7d09c71ab9e6433d409b4d78d2d00eb1ef3be863 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 27 Feb 2024 13:52:59 +0100 Subject: [PATCH] Strongly type "StreamType". --- api_proxy.go | 12 +++--- api_signaling.go | 7 ++++ clientsession.go | 52 ++++++++++++------------ clientsession_test.go | 12 +++--- grpc_client.go | 4 +- grpc_server.go | 2 +- hub.go | 34 ++++++++++++---- mcu_common.go | 27 +++++++++++-- mcu_janus.go | 91 +++++++++++++++++++++--------------------- mcu_proxy.go | 56 +++++++++++++------------- mcu_test.go | 10 ++--- proxy/proxy_server.go | 12 +++--- proxy/proxy_session.go | 8 ++-- 13 files changed, 188 insertions(+), 139 deletions(-) diff --git a/api_proxy.go b/api_proxy.go index 3184be4..62f5197 100644 --- a/api_proxy.go +++ b/api_proxy.go @@ -179,12 +179,12 @@ type ByeProxyServerMessage struct { type CommandProxyClientMessage struct { Type string `json:"type"` - Sid string `json:"sid,omitempty"` - StreamType string `json:"streamType,omitempty"` - PublisherId string `json:"publisherId,omitempty"` - ClientId string `json:"clientId,omitempty"` - Bitrate int `json:"bitrate,omitempty"` - MediaTypes MediaType `json:"mediatypes,omitempty"` + Sid string `json:"sid,omitempty"` + StreamType StreamType `json:"streamType,omitempty"` + PublisherId string `json:"publisherId,omitempty"` + ClientId string `json:"clientId,omitempty"` + Bitrate int `json:"bitrate,omitempty"` + MediaTypes MediaType `json:"mediatypes,omitempty"` } func (m *CommandProxyClientMessage) CheckValid() error { diff --git a/api_signaling.go b/api_signaling.go index 6cfdec5..98523f3 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -565,6 +565,13 @@ type MessageClientMessageData struct { Payload map[string]interface{} `json:"payload"` } +func (m *MessageClientMessageData) CheckValid() error { + if !IsValidStreamType(m.RoomType) { + return fmt.Errorf("invalid room type: %s", m.RoomType) + } + return nil +} + func (m *MessageClientMessage) CheckValid() error { if m.Data == nil || len(*m.Data) == 0 { return fmt.Errorf("message empty") diff --git a/clientsession.go b/clientsession.go index 1567278..53e8e55 100644 --- a/clientsession.go +++ b/clientsession.go @@ -79,7 +79,7 @@ type ClientSession struct { publisherWaiters ChannelWaiters - publishers map[string]McuPublisher + publishers map[StreamType]McuPublisher subscribers map[string]McuSubscriber pendingClientMessages []*ServerMessage @@ -356,7 +356,7 @@ func (s *ClientSession) getRoomJoinTime() time.Time { func (s *ClientSession) releaseMcuObjects() { if len(s.publishers) > 0 { - go func(publishers map[string]McuPublisher) { + go func(publishers map[StreamType]McuPublisher) { ctx := context.TODO() for _, publisher := range publishers { publisher.Close(ctx) @@ -573,12 +573,12 @@ func (s *ClientSession) SetClient(client *Client) *Client { return prev } -func (s *ClientSession) sendOffer(client McuClient, sender string, streamType string, offer map[string]interface{}) { +func (s *ClientSession) sendOffer(client McuClient, sender string, streamType StreamType, offer map[string]interface{}) { offer_message := &AnswerOfferMessage{ To: s.PublicId(), From: sender, Type: "offer", - RoomType: streamType, + RoomType: string(streamType), Payload: offer, Sid: client.Sid(), } @@ -601,12 +601,12 @@ func (s *ClientSession) sendOffer(client McuClient, sender string, streamType st s.sendMessageUnlocked(response_message) } -func (s *ClientSession) sendCandidate(client McuClient, sender string, streamType string, candidate interface{}) { +func (s *ClientSession) sendCandidate(client McuClient, sender string, streamType StreamType, candidate interface{}) { candidate_message := &AnswerOfferMessage{ To: s.PublicId(), From: sender, Type: "candidate", - RoomType: streamType, + RoomType: string(streamType), Payload: map[string]interface{}{ "candidate": candidate, }, @@ -839,15 +839,15 @@ func (s *ClientSession) IsAllowedToSend(data *MessageClientMessageData) error { } } -func (s *ClientSession) CheckOfferType(streamType string, data *MessageClientMessageData) (MediaType, error) { +func (s *ClientSession) CheckOfferType(streamType StreamType, data *MessageClientMessageData) (MediaType, error) { s.mu.Lock() defer s.mu.Unlock() return s.checkOfferTypeLocked(streamType, data) } -func (s *ClientSession) checkOfferTypeLocked(streamType string, data *MessageClientMessageData) (MediaType, error) { - if streamType == streamTypeScreen { +func (s *ClientSession) checkOfferTypeLocked(streamType StreamType, data *MessageClientMessageData) (MediaType, error) { + if streamType == StreamTypeScreen { if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) { return 0, &PermissionError{PERMISSION_MAY_PUBLISH_SCREEN} } @@ -865,7 +865,7 @@ func (s *ClientSession) checkOfferTypeLocked(streamType string, data *MessageCli return 0, nil } -func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, streamType string, data *MessageClientMessageData) (McuPublisher, error) { +func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, streamType StreamType, data *MessageClientMessageData) (McuPublisher, error) { s.mu.Lock() defer s.mu.Unlock() @@ -883,7 +883,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea bitrate := data.Bitrate if backend := s.Backend(); backend != nil { var maxBitrate int - if streamType == streamTypeScreen { + if streamType == StreamTypeScreen { maxBitrate = backend.maxScreenBitrate } else { maxBitrate = backend.maxStreamBitrate @@ -900,7 +900,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea return nil, err } if s.publishers == nil { - s.publishers = make(map[string]McuPublisher) + s.publishers = make(map[StreamType]McuPublisher) } if prev, found := s.publishers[streamType]; found { // Another thread created the publisher while we were waiting. @@ -921,18 +921,18 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea return publisher, nil } -func (s *ClientSession) getPublisherLocked(streamType string) McuPublisher { +func (s *ClientSession) getPublisherLocked(streamType StreamType) McuPublisher { return s.publishers[streamType] } -func (s *ClientSession) GetPublisher(streamType string) McuPublisher { +func (s *ClientSession) GetPublisher(streamType StreamType) McuPublisher { s.mu.Lock() defer s.mu.Unlock() return s.getPublisherLocked(streamType) } -func (s *ClientSession) GetOrWaitForPublisher(ctx context.Context, streamType string) McuPublisher { +func (s *ClientSession) GetOrWaitForPublisher(ctx context.Context, streamType StreamType) McuPublisher { s.mu.Lock() defer s.mu.Unlock() @@ -961,13 +961,13 @@ func (s *ClientSession) GetOrWaitForPublisher(ctx context.Context, streamType st } } -func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id string, streamType string) (McuSubscriber, error) { +func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id string, streamType StreamType) (McuSubscriber, error) { s.mu.Lock() defer s.mu.Unlock() // TODO(jojo): Add method to remove subscribers. - subscriber, found := s.subscribers[id+"|"+streamType] + subscriber, found := s.subscribers[getStreamId(id, streamType)] if !found { s.mu.Unlock() var err error @@ -979,7 +979,7 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s if s.subscribers == nil { s.subscribers = make(map[string]McuSubscriber) } - if prev, found := s.subscribers[id+"|"+streamType]; found { + if prev, found := s.subscribers[getStreamId(id, streamType)]; found { // Another thread created the subscriber while we were waiting. go func(sub McuSubscriber) { closeCtx := context.TODO() @@ -987,7 +987,7 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s }(subscriber) subscriber = prev } else { - s.subscribers[id+"|"+streamType] = subscriber + s.subscribers[getStreamId(id, streamType)] = subscriber } log.Printf("Subscribing %s from %s as %s in session %s", streamType, id, subscriber.Id(), s.PublicId()) } @@ -995,11 +995,11 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s return subscriber, nil } -func (s *ClientSession) GetSubscriber(id string, streamType string) McuSubscriber { +func (s *ClientSession) GetSubscriber(id string, streamType StreamType) McuSubscriber { s.mu.Lock() defer s.mu.Unlock() - return s.subscribers[id+"|"+streamType] + return s.subscribers[getStreamId(id, streamType)] } func (s *ClientSession) ProcessAsyncRoomMessage(message *AsyncMessage) { @@ -1023,10 +1023,10 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { defer s.mu.Unlock() if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_MEDIA) { - if publisher, found := s.publishers[streamTypeVideo]; found { + if publisher, found := s.publishers[StreamTypeVideo]; found { if (publisher.HasMedia(MediaTypeAudio) && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_AUDIO)) || (publisher.HasMedia(MediaTypeVideo) && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_VIDEO)) { - delete(s.publishers, streamTypeVideo) + delete(s.publishers, StreamTypeVideo) log.Printf("Session %s is no longer allowed to publish media, closing publisher %s", s.PublicId(), publisher.Id()) go func() { publisher.Close(context.Background()) @@ -1036,8 +1036,8 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { } } if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) { - if publisher, found := s.publishers[streamTypeScreen]; found { - delete(s.publishers, streamTypeScreen) + if publisher, found := s.publishers[StreamTypeScreen]; found { + delete(s.publishers, StreamTypeScreen) log.Printf("Session %s is no longer allowed to publish screen, closing publisher %s", s.PublicId(), publisher.Id()) go func() { publisher.Close(context.Background()) @@ -1059,7 +1059,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { ctx, cancel := context.WithTimeout(context.Background(), s.hub.mcuTimeout) defer cancel() - mc, err := s.GetOrCreateSubscriber(ctx, s.hub.mcu, message.SendOffer.SessionId, message.SendOffer.Data.RoomType) + mc, err := s.GetOrCreateSubscriber(ctx, s.hub.mcu, message.SendOffer.SessionId, StreamType(message.SendOffer.Data.RoomType)) if err != nil { log.Printf("Could not create MCU subscriber for session %s to process sendoffer in %s: %s", message.SendOffer.SessionId, s.PublicId(), err) if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{ diff --git a/clientsession_test.go b/clientsession_test.go index 0066eea..39f2531 100644 --- a/clientsession_test.go +++ b/clientsession_test.go @@ -222,16 +222,16 @@ func TestBandwidth_Backend(t *testing.T) { hub.SetMcu(mcu) - streamTypes := []string{ - streamTypeVideo, - streamTypeScreen, + streamTypes := []StreamType{ + StreamTypeVideo, + StreamTypeScreen, } ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() for _, streamType := range streamTypes { - t.Run(streamType, func(t *testing.T) { + t.Run(string(streamType), func(t *testing.T) { client := NewTestClient(t, server, hub) defer client.CloseWithBye() @@ -268,7 +268,7 @@ func TestBandwidth_Backend(t *testing.T) { }, MessageClientMessageData{ Type: "offer", Sid: "54321", - RoomType: streamType, + RoomType: string(streamType), Bitrate: bitrate, Payload: map[string]interface{}{ "sdp": MockSdpOfferAudioAndVideo, @@ -287,7 +287,7 @@ func TestBandwidth_Backend(t *testing.T) { } var expectBitrate int - if streamType == streamTypeVideo { + if streamType == StreamTypeVideo { expectBitrate = backend.maxStreamBitrate } else { expectBitrate = backend.maxScreenBitrate diff --git a/grpc_client.go b/grpc_client.go index b2a1855..8d50226 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -223,13 +223,13 @@ func (c *GrpcClient) IsSessionInCall(ctx context.Context, sessionId string, room return response.GetInCall(), nil } -func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, streamType string) (string, string, net.IP, error) { +func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, streamType StreamType) (string, string, net.IP, error) { statsGrpcClientCalls.WithLabelValues("GetPublisherId").Inc() // TODO: Remove debug logging log.Printf("Get %s publisher id %s on %s", streamType, sessionId, c.Target()) response, err := c.impl.GetPublisherId(ctx, &GetPublisherIdRequest{ SessionId: sessionId, - StreamType: streamType, + StreamType: string(streamType), }, grpc.WaitForReady(true)) if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { return "", "", nil, nil diff --git a/grpc_server.go b/grpc_server.go index 97b6368..3108be9 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -171,7 +171,7 @@ func (s *GrpcServer) GetPublisherId(ctx context.Context, request *GetPublisherId return nil, status.Error(codes.NotFound, "no such session") } - publisher := clientSession.GetOrWaitForPublisher(ctx, request.StreamType) + publisher := clientSession.GetOrWaitForPublisher(ctx, StreamType(request.StreamType)) if publisher, ok := publisher.(*mcuProxyPublisher); ok { reply := &GetPublisherIdReply{ PublisherId: publisher.Id(), diff --git a/hub.go b/hub.go index 7a01e1f..d8349c4 100644 --- a/hub.go +++ b/hub.go @@ -1445,6 +1445,16 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { // Maybe this is a message to be processed by the MCU. var data MessageClientMessageData if err := json.Unmarshal(*msg.Data, &data); err == nil { + if err := data.CheckValid(); err != nil { + log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + if err, ok := err.(*Error); ok { + session.SendMessage(message.NewErrorServerMessage(err)) + } else { + session.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + } + return + } + clientData = &data switch clientData.Type { @@ -1476,7 +1486,7 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { return } - publisher := session.GetPublisher(streamTypeScreen) + publisher := session.GetPublisher(StreamTypeScreen) if publisher == nil { return } @@ -1547,6 +1557,16 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { if h.mcu != nil { var data MessageClientMessageData if err := json.Unmarshal(*msg.Data, &data); err == nil { + if err := data.CheckValid(); err != nil { + log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + if err, ok := err.(*Error); ok { + session.SendMessage(message.NewErrorServerMessage(err)) + } else { + session.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + } + return + } + clientData = &data } } @@ -1586,7 +1606,7 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout) defer cancel() - mc, err := recipient.GetOrCreateSubscriber(ctx, h.mcu, session.PublicId(), clientData.RoomType) + mc, err := recipient.GetOrCreateSubscriber(ctx, h.mcu, session.PublicId(), StreamType(clientData.RoomType)) if err != nil { log.Printf("Could not create MCU subscriber for session %s to send %+v to %s: %s", session.PublicId(), clientData, recipient.PublicId(), err) sendMcuClientNotFound(session, message) @@ -2145,13 +2165,13 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe } clientType = "subscriber" - mc, err = session.GetOrCreateSubscriber(ctx, h.mcu, message.Recipient.SessionId, data.RoomType) + mc, err = session.GetOrCreateSubscriber(ctx, h.mcu, message.Recipient.SessionId, StreamType(data.RoomType)) case "sendoffer": // Will be sent directly. return case "offer": clientType = "publisher" - mc, err = session.GetOrCreatePublisher(ctx, h.mcu, data.RoomType, data) + mc, err = session.GetOrCreatePublisher(ctx, h.mcu, StreamType(data.RoomType), data) if err, ok := err.(*PermissionError); ok { log.Printf("Session %s is not allowed to offer %s, ignoring (%s)", session.PublicId(), data.RoomType, err) sendNotAllowed(session, client_message, "Not allowed to publish.") @@ -2169,7 +2189,7 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe } clientType = "subscriber" - mc = session.GetSubscriber(message.Recipient.SessionId, data.RoomType) + mc = session.GetSubscriber(message.Recipient.SessionId, StreamType(data.RoomType)) default: if session.PublicId() == message.Recipient.SessionId { if err := session.IsAllowedToSend(data); err != nil { @@ -2179,10 +2199,10 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe } clientType = "publisher" - mc = session.GetPublisher(data.RoomType) + mc = session.GetPublisher(StreamType(data.RoomType)) } else { clientType = "subscriber" - mc = session.GetSubscriber(message.Recipient.SessionId, data.RoomType) + mc = session.GetSubscriber(message.Recipient.SessionId, StreamType(data.RoomType)) } } if err != nil { diff --git a/mcu_common.go b/mcu_common.go index 9824443..6fe48c0 100644 --- a/mcu_common.go +++ b/mcu_common.go @@ -75,14 +75,35 @@ type Mcu interface { GetStats() interface{} - NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) - NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) + NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) + NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) +} + +type StreamType string + +const ( + StreamTypeAudio StreamType = "audio" + StreamTypeVideo StreamType = "video" + StreamTypeScreen StreamType = "screen" +) + +func IsValidStreamType(s string) bool { + switch s { + case string(StreamTypeAudio): + fallthrough + case string(StreamTypeVideo): + fallthrough + case string(StreamTypeScreen): + return true + default: + return false + } } type McuClient interface { Id() string Sid() string - StreamType() string + StreamType() StreamType Close(ctx context.Context) diff --git a/mcu_janus.go b/mcu_janus.go index e7a1e39..afe247b 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -50,18 +50,19 @@ const ( defaultMaxStreamBitrate = 1024 * 1024 defaultMaxScreenBitrate = 2048 * 1024 - - streamTypeVideo = "video" - streamTypeScreen = "screen" ) var ( - streamTypeUserIds = map[string]uint64{ - streamTypeVideo: videoPublisherUserId, - streamTypeScreen: screenPublisherUserId, + streamTypeUserIds = map[StreamType]uint64{ + StreamTypeVideo: videoPublisherUserId, + StreamTypeScreen: screenPublisherUserId, } ) +func getStreamId(publisherId string, streamType StreamType) string { + return fmt.Sprintf("%s|%s", publisherId, streamType) +} + func getPluginValue(data janus.PluginData, pluginName string, key string) interface{} { if data.Plugin != pluginName { return nil @@ -436,7 +437,7 @@ type mcuJanusClient struct { session uint64 roomId uint64 sid string - streamType string + streamType StreamType handle *JanusHandle handleId uint64 @@ -459,7 +460,7 @@ func (c *mcuJanusClient) Sid() string { return c.sid } -func (c *mcuJanusClient) StreamType() string { +func (c *mcuJanusClient) StreamType() StreamType { return c.streamType } @@ -609,7 +610,7 @@ func (c *mcuJanusClient) selectStream(ctx context.Context, stream *streamSelecti type publisherStatsCounter struct { mu sync.Mutex - streamTypes map[string]bool + streamTypes map[StreamType]bool subscribers map[string]bool } @@ -619,14 +620,14 @@ func (c *publisherStatsCounter) Reset() { count := len(c.subscribers) for streamType := range c.streamTypes { - statsMcuPublisherStreamTypesCurrent.WithLabelValues(streamType).Dec() - statsMcuSubscriberStreamTypesCurrent.WithLabelValues(streamType).Sub(float64(count)) + statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec() + statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(count)) } c.streamTypes = nil c.subscribers = nil } -func (c *publisherStatsCounter) EnableStream(streamType string, enable bool) { +func (c *publisherStatsCounter) EnableStream(streamType StreamType, enable bool) { c.mu.Lock() defer c.mu.Unlock() @@ -636,15 +637,15 @@ func (c *publisherStatsCounter) EnableStream(streamType string, enable bool) { if enable { if c.streamTypes == nil { - c.streamTypes = make(map[string]bool) + c.streamTypes = make(map[StreamType]bool) } c.streamTypes[streamType] = true - statsMcuPublisherStreamTypesCurrent.WithLabelValues(streamType).Inc() - statsMcuSubscriberStreamTypesCurrent.WithLabelValues(streamType).Add(float64(len(c.subscribers))) + statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Inc() + statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Add(float64(len(c.subscribers))) } else { delete(c.streamTypes, streamType) - statsMcuPublisherStreamTypesCurrent.WithLabelValues(streamType).Dec() - statsMcuSubscriberStreamTypesCurrent.WithLabelValues(streamType).Sub(float64(len(c.subscribers))) + statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec() + statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(len(c.subscribers))) } } @@ -661,7 +662,7 @@ func (c *publisherStatsCounter) AddSubscriber(id string) { } c.subscribers[id] = true for streamType := range c.streamTypes { - statsMcuSubscriberStreamTypesCurrent.WithLabelValues(streamType).Inc() + statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Inc() } } @@ -675,7 +676,7 @@ func (c *publisherStatsCounter) RemoveSubscriber(id string) { delete(c.subscribers, id) for streamType := range c.streamTypes { - statsMcuSubscriberStreamTypesCurrent.WithLabelValues(streamType).Dec() + statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Dec() } } @@ -688,20 +689,20 @@ type mcuJanusPublisher struct { stats publisherStatsCounter } -func (m *mcuJanus) SubscriberConnected(id string, publisher string, streamType string) { +func (m *mcuJanus) SubscriberConnected(id string, publisher string, streamType StreamType) { m.mu.Lock() defer m.mu.Unlock() - if p, found := m.publishers[publisher+"|"+streamType]; found { + if p, found := m.publishers[getStreamId(publisher, streamType)]; found { p.stats.AddSubscriber(id) } } -func (m *mcuJanus) SubscriberDisconnected(id string, publisher string, streamType string) { +func (m *mcuJanus) SubscriberDisconnected(id string, publisher string, streamType StreamType) { m.mu.Lock() defer m.mu.Unlock() - if p, found := m.publishers[publisher+"|"+streamType]; found { + if p, found := m.publishers[getStreamId(publisher, streamType)]; found { p.stats.RemoveSubscriber(id) } } @@ -714,7 +715,7 @@ func min(a, b int) int { return b } -func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, streamType string, bitrate int) (*JanusHandle, uint64, uint64, error) { +func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, streamType StreamType, bitrate int) (*JanusHandle, uint64, uint64, error) { session := m.session if session == nil { return nil, 0, 0, ErrNotConnected @@ -727,7 +728,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st log.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id) create_msg := map[string]interface{}{ "request": "create", - "description": id + "|" + streamType, + "description": getStreamId(id, streamType), // We publish every stream in its own Janus room. "publishers": 1, // Do not use the video-orientation RTP extension as it breaks video @@ -735,7 +736,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st "videoorient_ext": false, } var maxBitrate int - if streamType == streamTypeScreen { + if streamType == StreamTypeScreen { maxBitrate = m.maxScreenBitrate } else { maxBitrate = m.maxStreamBitrate @@ -782,7 +783,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st return handle, response.Session, roomId, nil } -func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { +func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { if _, found := streamTypeUserIds[streamType]; !found { return nil, fmt.Errorf("Unsupported stream type %s", streamType) } @@ -823,11 +824,11 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id st log.Printf("Publisher %s is using handle %d", client.id, client.handleId) go client.run(handle, client.closeChan) m.mu.Lock() - m.publishers[id+"|"+streamType] = client - m.publisherCreated.Notify(id + "|" + streamType) + m.publishers[getStreamId(id, streamType)] = client + m.publisherCreated.Notify(getStreamId(id, streamType)) m.mu.Unlock() - statsPublishersCurrent.WithLabelValues(streamType).Inc() - statsPublishersTotal.WithLabelValues(streamType).Inc() + statsPublishersCurrent.WithLabelValues(string(streamType)).Inc() + statsPublishersTotal.WithLabelValues(string(streamType)).Inc() return client, nil } @@ -860,7 +861,7 @@ func (p *mcuJanusPublisher) handleDetached(event *janus.DetachedMsg) { func (p *mcuJanusPublisher) handleConnected(event *janus.WebRTCUpMsg) { log.Printf("Publisher %d received connected", p.handleId) - p.mcu.publisherConnected.Notify(p.id + "|" + p.streamType) + p.mcu.publisherConnected.Notify(getStreamId(p.id, p.streamType)) } func (p *mcuJanusPublisher) handleSlowLink(event *janus.SlowLinkMsg) { @@ -872,8 +873,8 @@ func (p *mcuJanusPublisher) handleSlowLink(event *janus.SlowLinkMsg) { } func (p *mcuJanusPublisher) handleMedia(event *janus.MediaMsg) { - mediaType := event.Type - if mediaType == "video" && p.streamType == "screen" { + mediaType := StreamType(event.Type) + if mediaType == StreamTypeVideo && p.streamType == StreamTypeScreen { // We want to differentiate between audio, video and screensharing mediaType = p.streamType } @@ -920,7 +921,7 @@ func (p *mcuJanusPublisher) Close(ctx context.Context) { log.Printf("Room %d destroyed", p.roomId) } p.mcu.mu.Lock() - delete(p.mcu.publishers, p.id+"|"+p.streamType) + delete(p.mcu.publishers, getStreamId(p.id, p.streamType)) p.mcu.mu.Unlock() p.roomId = 0 notify = true @@ -931,7 +932,7 @@ func (p *mcuJanusPublisher) Close(ctx context.Context) { p.stats.Reset() if notify { - statsPublishersCurrent.WithLabelValues(p.streamType).Dec() + statsPublishersCurrent.WithLabelValues(string(p.streamType)).Dec() p.mcu.unregisterClient(p) p.listener.PublisherClosed(p) } @@ -975,9 +976,9 @@ type mcuJanusSubscriber struct { publisher string } -func (m *mcuJanus) getPublisher(ctx context.Context, publisher string, streamType string) (*mcuJanusPublisher, error) { +func (m *mcuJanus) getPublisher(ctx context.Context, publisher string, streamType StreamType) (*mcuJanusPublisher, error) { // Do the direct check immediately as this should be the normal case. - key := publisher + "|" + streamType + key := getStreamId(publisher, streamType) m.mu.Lock() if result, found := m.publishers[key]; found { m.mu.Unlock() @@ -1002,7 +1003,7 @@ func (m *mcuJanus) getPublisher(ctx context.Context, publisher string, streamTyp } } -func (m *mcuJanus) getOrCreateSubscriberHandle(ctx context.Context, publisher string, streamType string) (*JanusHandle, *mcuJanusPublisher, error) { +func (m *mcuJanus) getOrCreateSubscriberHandle(ctx context.Context, publisher string, streamType StreamType) (*JanusHandle, *mcuJanusPublisher, error) { var pub *mcuJanusPublisher var err error if pub, err = m.getPublisher(ctx, publisher, streamType); err != nil { @@ -1023,7 +1024,7 @@ func (m *mcuJanus) getOrCreateSubscriberHandle(ctx context.Context, publisher st return handle, pub, nil } -func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { +func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) { if _, found := streamTypeUserIds[streamType]; !found { return nil, fmt.Errorf("Unsupported stream type %s", streamType) } @@ -1058,8 +1059,8 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ client.mcuJanusClient.handleMedia = client.handleMedia m.registerClient(client) go client.run(handle, client.closeChan) - statsSubscribersCurrent.WithLabelValues(streamType).Inc() - statsSubscribersTotal.WithLabelValues(streamType).Inc() + statsSubscribersCurrent.WithLabelValues(string(streamType)).Inc() + statsSubscribersTotal.WithLabelValues(string(streamType)).Inc() return client, nil } @@ -1144,7 +1145,7 @@ func (p *mcuJanusSubscriber) Close(ctx context.Context) { if closed { p.mcu.SubscriberDisconnected(p.Id(), p.publisher, p.streamType) - statsSubscribersCurrent.WithLabelValues(p.streamType).Dec() + statsSubscribersCurrent.WithLabelValues(string(p.streamType)).Dec() } p.mcu.unregisterClient(p) p.listener.SubscriberClosed(p) @@ -1158,7 +1159,7 @@ func (p *mcuJanusSubscriber) joinRoom(ctx context.Context, stream *streamSelecti return } - waiter := p.mcu.publisherConnected.NewWaiter(p.publisher + "|" + p.streamType) + waiter := p.mcu.publisherConnected.NewWaiter(getStreamId(p.publisher, p.streamType)) defer p.mcu.publisherConnected.Release(waiter) loggedNotPublishingYet := false @@ -1223,7 +1224,7 @@ retry: if !loggedNotPublishingYet { loggedNotPublishingYet = true - statsWaitingForPublisherTotal.WithLabelValues(p.streamType).Inc() + statsWaitingForPublisherTotal.WithLabelValues(string(p.streamType)).Inc() } if err := waiter.Wait(ctx); err != nil { diff --git a/mcu_proxy.go b/mcu_proxy.go index 9186a64..eeff1de 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -76,7 +76,7 @@ type McuProxy interface { type mcuProxyPubSubCommon struct { sid string - streamType string + streamType StreamType proxyId string conn *mcuProxyConnection listener McuListener @@ -90,7 +90,7 @@ func (c *mcuProxyPubSubCommon) Sid() string { return c.sid } -func (c *mcuProxyPubSubCommon) StreamType() string { +func (c *mcuProxyPubSubCommon) StreamType() StreamType { return c.streamType } @@ -132,7 +132,7 @@ type mcuProxyPublisher struct { mediaTypes MediaType } -func newMcuProxyPublisher(id string, sid string, streamType string, mediaTypes MediaType, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher { +func newMcuProxyPublisher(id string, sid string, streamType StreamType, mediaTypes MediaType, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher { return &mcuProxyPublisher{ mcuProxyPubSubCommon: mcuProxyPubSubCommon{ sid: sid, @@ -217,7 +217,7 @@ type mcuProxySubscriber struct { publisherId string } -func newMcuProxySubscriber(publisherId string, sid string, streamType string, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxySubscriber { +func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxySubscriber { return &mcuProxySubscriber{ mcuProxyPubSubCommon: mcuProxyPubSubCommon{ sid: sid, @@ -719,9 +719,9 @@ func (c *mcuProxyConnection) removePublisher(publisher *mcuProxyPublisher) { if _, found := c.publishers[publisher.proxyId]; found { delete(c.publishers, publisher.proxyId) - statsPublishersCurrent.WithLabelValues(publisher.StreamType()).Dec() + statsPublishersCurrent.WithLabelValues(string(publisher.StreamType())).Dec() } - delete(c.publisherIds, publisher.id+"|"+publisher.StreamType()) + delete(c.publisherIds, getStreamId(publisher.id, publisher.StreamType())) if len(c.publishers) == 0 && (c.closeScheduled.Load() || c.IsTemporary()) { go c.closeIfEmpty() @@ -751,7 +751,7 @@ func (c *mcuProxyConnection) removeSubscriber(subscriber *mcuProxySubscriber) { if _, found := c.subscribers[subscriber.proxyId]; found { delete(c.subscribers, subscriber.proxyId) - statsSubscribersCurrent.WithLabelValues(subscriber.StreamType()).Dec() + statsSubscribersCurrent.WithLabelValues(string(subscriber.StreamType())).Dec() } if len(c.subscribers) == 0 && (c.closeScheduled.Load() || c.IsTemporary()) { @@ -1032,7 +1032,7 @@ func (c *mcuProxyConnection) performSyncRequest(ctx context.Context, msg *ProxyC } } -func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType) (McuPublisher, error) { +func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType) (McuPublisher, error) { msg := &ProxyClientMessage{ Type: "command", Command: &CommandProxyClientMessage{ @@ -1057,14 +1057,14 @@ func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListe publisher := newMcuProxyPublisher(id, sid, streamType, mediaTypes, proxyId, c, listener) c.publishersLock.Lock() c.publishers[proxyId] = publisher - c.publisherIds[id+"|"+streamType] = proxyId + c.publisherIds[getStreamId(id, streamType)] = proxyId c.publishersLock.Unlock() - statsPublishersCurrent.WithLabelValues(streamType).Inc() - statsPublishersTotal.WithLabelValues(streamType).Inc() + statsPublishersCurrent.WithLabelValues(string(streamType)).Inc() + statsPublishersTotal.WithLabelValues(string(streamType)).Inc() return publisher, nil } -func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuListener, publisherId string, publisherSessionId string, streamType string) (McuSubscriber, error) { +func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuListener, publisherId string, publisherSessionId string, streamType StreamType) (McuSubscriber, error) { msg := &ProxyClientMessage{ Type: "command", Command: &CommandProxyClientMessage{ @@ -1088,8 +1088,8 @@ func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuList c.subscribersLock.Lock() c.subscribers[proxyId] = subscriber c.subscribersLock.Unlock() - statsSubscribersCurrent.WithLabelValues(streamType).Inc() - statsSubscribersTotal.WithLabelValues(streamType).Inc() + statsSubscribersCurrent.WithLabelValues(string(streamType)).Inc() + statsSubscribersTotal.WithLabelValues(string(streamType)).Inc() return subscriber, nil } @@ -1555,10 +1555,10 @@ func (m *mcuProxy) removePublisher(publisher *mcuProxyPublisher) { m.mu.Lock() defer m.mu.Unlock() - delete(m.publishers, publisher.id+"|"+publisher.StreamType()) + delete(m.publishers, getStreamId(publisher.id, publisher.StreamType())) } -func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { +func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { connections := m.getSortedConnections(initiator) for _, conn := range connections { if conn.IsShutdownScheduled() || conn.IsTemporary() { @@ -1569,7 +1569,7 @@ func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id st defer cancel() var maxBitrate int - if streamType == streamTypeScreen { + if streamType == StreamTypeScreen { maxBitrate = m.maxScreenBitrate } else { maxBitrate = m.maxStreamBitrate @@ -1586,28 +1586,28 @@ func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id st } m.mu.Lock() - m.publishers[id+"|"+streamType] = conn + m.publishers[getStreamId(id, streamType)] = conn m.mu.Unlock() m.publisherWaiters.Wakeup() return publisher, nil } - statsProxyNobackendAvailableTotal.WithLabelValues(streamType).Inc() + statsProxyNobackendAvailableTotal.WithLabelValues(string(streamType)).Inc() return nil, fmt.Errorf("No MCU connection available") } -func (m *mcuProxy) getPublisherConnection(publisher string, streamType string) *mcuProxyConnection { +func (m *mcuProxy) getPublisherConnection(publisher string, streamType StreamType) *mcuProxyConnection { m.mu.RLock() defer m.mu.RUnlock() - return m.publishers[publisher+"|"+streamType] + return m.publishers[getStreamId(publisher, streamType)] } -func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher string, streamType string) *mcuProxyConnection { +func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher string, streamType StreamType) *mcuProxyConnection { m.mu.Lock() defer m.mu.Unlock() - conn := m.publishers[publisher+"|"+streamType] + conn := m.publishers[getStreamId(publisher, streamType)] if conn != nil { // Publisher was created while waiting for lock. return conn @@ -1617,13 +1617,13 @@ func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher str id := m.publisherWaiters.Add(ch) defer m.publisherWaiters.Remove(id) - statsWaitingForPublisherTotal.WithLabelValues(streamType).Inc() + statsWaitingForPublisherTotal.WithLabelValues(string(streamType)).Inc() for { m.mu.Unlock() select { case <-ch: m.mu.Lock() - conn = m.publishers[publisher+"|"+streamType] + conn = m.publishers[getStreamId(publisher, streamType)] if conn != nil { return conn } @@ -1634,11 +1634,11 @@ func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher str } } -func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { +func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) { if conn := m.getPublisherConnection(publisher, streamType); conn != nil { // Fast common path: publisher is available locally. conn.publishersLock.Lock() - id, found := conn.publisherIds[publisher+"|"+streamType] + id, found := conn.publisherIds[getStreamId(publisher, streamType)] conn.publishersLock.Unlock() if !found { return nil, fmt.Errorf("Unknown publisher %s", publisher) @@ -1658,7 +1658,7 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ cancel() // Cancel pending RPC calls. conn.publishersLock.Lock() - id, found := conn.publisherIds[publisher+"|"+streamType] + id, found := conn.publisherIds[getStreamId(publisher, streamType)] conn.publishersLock.Unlock() if !found { log.Printf("Unknown id for local %s publisher %s", streamType, publisher) diff --git a/mcu_test.go b/mcu_test.go index 1cc88b1..42de651 100644 --- a/mcu_test.go +++ b/mcu_test.go @@ -69,9 +69,9 @@ func (m *TestMCU) GetStats() interface{} { return nil } -func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { +func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { var maxBitrate int - if streamType == streamTypeScreen { + if streamType == StreamTypeScreen { maxBitrate = TestMaxBitrateScreen } else { maxBitrate = TestMaxBitrateVideo @@ -117,7 +117,7 @@ func (m *TestMCU) GetPublisher(id string) *TestMCUPublisher { return m.publishers[id] } -func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { +func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) { m.mu.Lock() defer m.mu.Unlock() @@ -143,7 +143,7 @@ type TestMCUClient struct { id string sid string - streamType string + streamType StreamType } func (c *TestMCUClient) Id() string { @@ -154,7 +154,7 @@ func (c *TestMCUClient) Sid() string { return c.sid } -func (c *TestMCUClient) StreamType() string { +func (c *TestMCUClient) StreamType() StreamType { return c.streamType } diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index d2ef29a..f950b1a 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -657,8 +657,8 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s }, } session.sendMessage(response) - statsPublishersCurrent.WithLabelValues(cmd.StreamType).Inc() - statsPublishersTotal.WithLabelValues(cmd.StreamType).Inc() + statsPublishersCurrent.WithLabelValues(string(cmd.StreamType)).Inc() + statsPublishersTotal.WithLabelValues(string(cmd.StreamType)).Inc() case "create-subscriber": id := uuid.New().String() publisherId := cmd.PublisherId @@ -686,8 +686,8 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s }, } session.sendMessage(response) - statsSubscribersCurrent.WithLabelValues(cmd.StreamType).Inc() - statsSubscribersTotal.WithLabelValues(cmd.StreamType).Inc() + statsSubscribersCurrent.WithLabelValues(string(cmd.StreamType)).Inc() + statsSubscribersTotal.WithLabelValues(string(cmd.StreamType)).Inc() case "delete-publisher": client := s.GetClient(cmd.ClientId) if client == nil { @@ -707,7 +707,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } if s.DeleteClient(cmd.ClientId, client) { - statsPublishersCurrent.WithLabelValues(client.StreamType()).Dec() + statsPublishersCurrent.WithLabelValues(string(client.StreamType())).Dec() } go func() { @@ -742,7 +742,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } if s.DeleteClient(cmd.ClientId, client) { - statsSubscribersCurrent.WithLabelValues(client.StreamType()).Dec() + statsSubscribersCurrent.WithLabelValues(string(client.StreamType())).Dec() } go func() { diff --git a/proxy/proxy_session.go b/proxy/proxy_session.go index 80445b2..5c476e0 100644 --- a/proxy/proxy_session.go +++ b/proxy/proxy_session.go @@ -212,7 +212,7 @@ func (s *ProxySession) SubscriberSidUpdated(subscriber signaling.McuSubscriber) func (s *ProxySession) PublisherClosed(publisher signaling.McuPublisher) { if id := s.DeletePublisher(publisher); id != "" { if s.proxy.DeleteClient(id, publisher) { - statsPublishersCurrent.WithLabelValues(publisher.StreamType()).Dec() + statsPublishersCurrent.WithLabelValues(string(publisher.StreamType())).Dec() } msg := &signaling.ProxyServerMessage{ @@ -229,7 +229,7 @@ func (s *ProxySession) PublisherClosed(publisher signaling.McuPublisher) { func (s *ProxySession) SubscriberClosed(subscriber signaling.McuSubscriber) { if id := s.DeleteSubscriber(subscriber); id != "" { if s.proxy.DeleteClient(id, subscriber) { - statsSubscribersCurrent.WithLabelValues(subscriber.StreamType()).Dec() + statsSubscribersCurrent.WithLabelValues(string(subscriber.StreamType())).Dec() } msg := &signaling.ProxyServerMessage{ @@ -294,7 +294,7 @@ func (s *ProxySession) clearPublishers() { go func(publishers map[string]signaling.McuPublisher) { for id, publisher := range publishers { if s.proxy.DeleteClient(id, publisher) { - statsPublishersCurrent.WithLabelValues(publisher.StreamType()).Dec() + statsPublishersCurrent.WithLabelValues(string(publisher.StreamType())).Dec() } publisher.Close(context.Background()) } @@ -310,7 +310,7 @@ func (s *ProxySession) clearSubscribers() { go func(subscribers map[string]signaling.McuSubscriber) { for id, subscriber := range subscribers { if s.proxy.DeleteClient(id, subscriber) { - statsSubscribersCurrent.WithLabelValues(subscriber.StreamType()).Dec() + statsSubscribersCurrent.WithLabelValues(string(subscriber.StreamType())).Dec() } subscriber.Close(context.Background()) }