Strongly type "StreamType".

This commit is contained in:
Joachim Bauch 2024-02-27 13:52:59 +01:00
parent 26a65cedd1
commit 7d09c71ab9
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
13 changed files with 188 additions and 139 deletions

View file

@ -179,12 +179,12 @@ type ByeProxyServerMessage struct {
type CommandProxyClientMessage struct { type CommandProxyClientMessage struct {
Type string `json:"type"` Type string `json:"type"`
Sid string `json:"sid,omitempty"` Sid string `json:"sid,omitempty"`
StreamType string `json:"streamType,omitempty"` StreamType StreamType `json:"streamType,omitempty"`
PublisherId string `json:"publisherId,omitempty"` PublisherId string `json:"publisherId,omitempty"`
ClientId string `json:"clientId,omitempty"` ClientId string `json:"clientId,omitempty"`
Bitrate int `json:"bitrate,omitempty"` Bitrate int `json:"bitrate,omitempty"`
MediaTypes MediaType `json:"mediatypes,omitempty"` MediaTypes MediaType `json:"mediatypes,omitempty"`
} }
func (m *CommandProxyClientMessage) CheckValid() error { func (m *CommandProxyClientMessage) CheckValid() error {

View file

@ -565,6 +565,13 @@ type MessageClientMessageData struct {
Payload map[string]interface{} `json:"payload"` Payload map[string]interface{} `json:"payload"`
} }
func (m *MessageClientMessageData) CheckValid() error {
if !IsValidStreamType(m.RoomType) {
return fmt.Errorf("invalid room type: %s", m.RoomType)
}
return nil
}
func (m *MessageClientMessage) CheckValid() error { func (m *MessageClientMessage) CheckValid() error {
if m.Data == nil || len(*m.Data) == 0 { if m.Data == nil || len(*m.Data) == 0 {
return fmt.Errorf("message empty") return fmt.Errorf("message empty")

View file

@ -79,7 +79,7 @@ type ClientSession struct {
publisherWaiters ChannelWaiters publisherWaiters ChannelWaiters
publishers map[string]McuPublisher publishers map[StreamType]McuPublisher
subscribers map[string]McuSubscriber subscribers map[string]McuSubscriber
pendingClientMessages []*ServerMessage pendingClientMessages []*ServerMessage
@ -356,7 +356,7 @@ func (s *ClientSession) getRoomJoinTime() time.Time {
func (s *ClientSession) releaseMcuObjects() { func (s *ClientSession) releaseMcuObjects() {
if len(s.publishers) > 0 { if len(s.publishers) > 0 {
go func(publishers map[string]McuPublisher) { go func(publishers map[StreamType]McuPublisher) {
ctx := context.TODO() ctx := context.TODO()
for _, publisher := range publishers { for _, publisher := range publishers {
publisher.Close(ctx) publisher.Close(ctx)
@ -573,12 +573,12 @@ func (s *ClientSession) SetClient(client *Client) *Client {
return prev return prev
} }
func (s *ClientSession) sendOffer(client McuClient, sender string, streamType string, offer map[string]interface{}) { func (s *ClientSession) sendOffer(client McuClient, sender string, streamType StreamType, offer map[string]interface{}) {
offer_message := &AnswerOfferMessage{ offer_message := &AnswerOfferMessage{
To: s.PublicId(), To: s.PublicId(),
From: sender, From: sender,
Type: "offer", Type: "offer",
RoomType: streamType, RoomType: string(streamType),
Payload: offer, Payload: offer,
Sid: client.Sid(), Sid: client.Sid(),
} }
@ -601,12 +601,12 @@ func (s *ClientSession) sendOffer(client McuClient, sender string, streamType st
s.sendMessageUnlocked(response_message) s.sendMessageUnlocked(response_message)
} }
func (s *ClientSession) sendCandidate(client McuClient, sender string, streamType string, candidate interface{}) { func (s *ClientSession) sendCandidate(client McuClient, sender string, streamType StreamType, candidate interface{}) {
candidate_message := &AnswerOfferMessage{ candidate_message := &AnswerOfferMessage{
To: s.PublicId(), To: s.PublicId(),
From: sender, From: sender,
Type: "candidate", Type: "candidate",
RoomType: streamType, RoomType: string(streamType),
Payload: map[string]interface{}{ Payload: map[string]interface{}{
"candidate": candidate, "candidate": candidate,
}, },
@ -839,15 +839,15 @@ func (s *ClientSession) IsAllowedToSend(data *MessageClientMessageData) error {
} }
} }
func (s *ClientSession) CheckOfferType(streamType string, data *MessageClientMessageData) (MediaType, error) { func (s *ClientSession) CheckOfferType(streamType StreamType, data *MessageClientMessageData) (MediaType, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
return s.checkOfferTypeLocked(streamType, data) return s.checkOfferTypeLocked(streamType, data)
} }
func (s *ClientSession) checkOfferTypeLocked(streamType string, data *MessageClientMessageData) (MediaType, error) { func (s *ClientSession) checkOfferTypeLocked(streamType StreamType, data *MessageClientMessageData) (MediaType, error) {
if streamType == streamTypeScreen { if streamType == StreamTypeScreen {
if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) { if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) {
return 0, &PermissionError{PERMISSION_MAY_PUBLISH_SCREEN} return 0, &PermissionError{PERMISSION_MAY_PUBLISH_SCREEN}
} }
@ -865,7 +865,7 @@ func (s *ClientSession) checkOfferTypeLocked(streamType string, data *MessageCli
return 0, nil return 0, nil
} }
func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, streamType string, data *MessageClientMessageData) (McuPublisher, error) { func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, streamType StreamType, data *MessageClientMessageData) (McuPublisher, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -883,7 +883,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea
bitrate := data.Bitrate bitrate := data.Bitrate
if backend := s.Backend(); backend != nil { if backend := s.Backend(); backend != nil {
var maxBitrate int var maxBitrate int
if streamType == streamTypeScreen { if streamType == StreamTypeScreen {
maxBitrate = backend.maxScreenBitrate maxBitrate = backend.maxScreenBitrate
} else { } else {
maxBitrate = backend.maxStreamBitrate maxBitrate = backend.maxStreamBitrate
@ -900,7 +900,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea
return nil, err return nil, err
} }
if s.publishers == nil { if s.publishers == nil {
s.publishers = make(map[string]McuPublisher) s.publishers = make(map[StreamType]McuPublisher)
} }
if prev, found := s.publishers[streamType]; found { if prev, found := s.publishers[streamType]; found {
// Another thread created the publisher while we were waiting. // Another thread created the publisher while we were waiting.
@ -921,18 +921,18 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea
return publisher, nil return publisher, nil
} }
func (s *ClientSession) getPublisherLocked(streamType string) McuPublisher { func (s *ClientSession) getPublisherLocked(streamType StreamType) McuPublisher {
return s.publishers[streamType] return s.publishers[streamType]
} }
func (s *ClientSession) GetPublisher(streamType string) McuPublisher { func (s *ClientSession) GetPublisher(streamType StreamType) McuPublisher {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
return s.getPublisherLocked(streamType) return s.getPublisherLocked(streamType)
} }
func (s *ClientSession) GetOrWaitForPublisher(ctx context.Context, streamType string) McuPublisher { func (s *ClientSession) GetOrWaitForPublisher(ctx context.Context, streamType StreamType) McuPublisher {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -961,13 +961,13 @@ func (s *ClientSession) GetOrWaitForPublisher(ctx context.Context, streamType st
} }
} }
func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id string, streamType string) (McuSubscriber, error) { func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id string, streamType StreamType) (McuSubscriber, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// TODO(jojo): Add method to remove subscribers. // TODO(jojo): Add method to remove subscribers.
subscriber, found := s.subscribers[id+"|"+streamType] subscriber, found := s.subscribers[getStreamId(id, streamType)]
if !found { if !found {
s.mu.Unlock() s.mu.Unlock()
var err error var err error
@ -979,7 +979,7 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s
if s.subscribers == nil { if s.subscribers == nil {
s.subscribers = make(map[string]McuSubscriber) s.subscribers = make(map[string]McuSubscriber)
} }
if prev, found := s.subscribers[id+"|"+streamType]; found { if prev, found := s.subscribers[getStreamId(id, streamType)]; found {
// Another thread created the subscriber while we were waiting. // Another thread created the subscriber while we were waiting.
go func(sub McuSubscriber) { go func(sub McuSubscriber) {
closeCtx := context.TODO() closeCtx := context.TODO()
@ -987,7 +987,7 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s
}(subscriber) }(subscriber)
subscriber = prev subscriber = prev
} else { } else {
s.subscribers[id+"|"+streamType] = subscriber s.subscribers[getStreamId(id, streamType)] = subscriber
} }
log.Printf("Subscribing %s from %s as %s in session %s", streamType, id, subscriber.Id(), s.PublicId()) log.Printf("Subscribing %s from %s as %s in session %s", streamType, id, subscriber.Id(), s.PublicId())
} }
@ -995,11 +995,11 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s
return subscriber, nil return subscriber, nil
} }
func (s *ClientSession) GetSubscriber(id string, streamType string) McuSubscriber { func (s *ClientSession) GetSubscriber(id string, streamType StreamType) McuSubscriber {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
return s.subscribers[id+"|"+streamType] return s.subscribers[getStreamId(id, streamType)]
} }
func (s *ClientSession) ProcessAsyncRoomMessage(message *AsyncMessage) { func (s *ClientSession) ProcessAsyncRoomMessage(message *AsyncMessage) {
@ -1023,10 +1023,10 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
defer s.mu.Unlock() defer s.mu.Unlock()
if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_MEDIA) { if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_MEDIA) {
if publisher, found := s.publishers[streamTypeVideo]; found { if publisher, found := s.publishers[StreamTypeVideo]; found {
if (publisher.HasMedia(MediaTypeAudio) && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_AUDIO)) || if (publisher.HasMedia(MediaTypeAudio) && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_AUDIO)) ||
(publisher.HasMedia(MediaTypeVideo) && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_VIDEO)) { (publisher.HasMedia(MediaTypeVideo) && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_VIDEO)) {
delete(s.publishers, streamTypeVideo) delete(s.publishers, StreamTypeVideo)
log.Printf("Session %s is no longer allowed to publish media, closing publisher %s", s.PublicId(), publisher.Id()) log.Printf("Session %s is no longer allowed to publish media, closing publisher %s", s.PublicId(), publisher.Id())
go func() { go func() {
publisher.Close(context.Background()) publisher.Close(context.Background())
@ -1036,8 +1036,8 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
} }
} }
if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) { if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) {
if publisher, found := s.publishers[streamTypeScreen]; found { if publisher, found := s.publishers[StreamTypeScreen]; found {
delete(s.publishers, streamTypeScreen) delete(s.publishers, StreamTypeScreen)
log.Printf("Session %s is no longer allowed to publish screen, closing publisher %s", s.PublicId(), publisher.Id()) log.Printf("Session %s is no longer allowed to publish screen, closing publisher %s", s.PublicId(), publisher.Id())
go func() { go func() {
publisher.Close(context.Background()) publisher.Close(context.Background())
@ -1059,7 +1059,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
ctx, cancel := context.WithTimeout(context.Background(), s.hub.mcuTimeout) ctx, cancel := context.WithTimeout(context.Background(), s.hub.mcuTimeout)
defer cancel() defer cancel()
mc, err := s.GetOrCreateSubscriber(ctx, s.hub.mcu, message.SendOffer.SessionId, message.SendOffer.Data.RoomType) mc, err := s.GetOrCreateSubscriber(ctx, s.hub.mcu, message.SendOffer.SessionId, StreamType(message.SendOffer.Data.RoomType))
if err != nil { if err != nil {
log.Printf("Could not create MCU subscriber for session %s to process sendoffer in %s: %s", message.SendOffer.SessionId, s.PublicId(), err) log.Printf("Could not create MCU subscriber for session %s to process sendoffer in %s: %s", message.SendOffer.SessionId, s.PublicId(), err)
if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{ if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{

View file

@ -222,16 +222,16 @@ func TestBandwidth_Backend(t *testing.T) {
hub.SetMcu(mcu) hub.SetMcu(mcu)
streamTypes := []string{ streamTypes := []StreamType{
streamTypeVideo, StreamTypeVideo,
streamTypeScreen, StreamTypeScreen,
} }
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel() defer cancel()
for _, streamType := range streamTypes { for _, streamType := range streamTypes {
t.Run(streamType, func(t *testing.T) { t.Run(string(streamType), func(t *testing.T) {
client := NewTestClient(t, server, hub) client := NewTestClient(t, server, hub)
defer client.CloseWithBye() defer client.CloseWithBye()
@ -268,7 +268,7 @@ func TestBandwidth_Backend(t *testing.T) {
}, MessageClientMessageData{ }, MessageClientMessageData{
Type: "offer", Type: "offer",
Sid: "54321", Sid: "54321",
RoomType: streamType, RoomType: string(streamType),
Bitrate: bitrate, Bitrate: bitrate,
Payload: map[string]interface{}{ Payload: map[string]interface{}{
"sdp": MockSdpOfferAudioAndVideo, "sdp": MockSdpOfferAudioAndVideo,
@ -287,7 +287,7 @@ func TestBandwidth_Backend(t *testing.T) {
} }
var expectBitrate int var expectBitrate int
if streamType == streamTypeVideo { if streamType == StreamTypeVideo {
expectBitrate = backend.maxStreamBitrate expectBitrate = backend.maxStreamBitrate
} else { } else {
expectBitrate = backend.maxScreenBitrate expectBitrate = backend.maxScreenBitrate

View file

@ -223,13 +223,13 @@ func (c *GrpcClient) IsSessionInCall(ctx context.Context, sessionId string, room
return response.GetInCall(), nil return response.GetInCall(), nil
} }
func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, streamType string) (string, string, net.IP, error) { func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, streamType StreamType) (string, string, net.IP, error) {
statsGrpcClientCalls.WithLabelValues("GetPublisherId").Inc() statsGrpcClientCalls.WithLabelValues("GetPublisherId").Inc()
// TODO: Remove debug logging // TODO: Remove debug logging
log.Printf("Get %s publisher id %s on %s", streamType, sessionId, c.Target()) log.Printf("Get %s publisher id %s on %s", streamType, sessionId, c.Target())
response, err := c.impl.GetPublisherId(ctx, &GetPublisherIdRequest{ response, err := c.impl.GetPublisherId(ctx, &GetPublisherIdRequest{
SessionId: sessionId, SessionId: sessionId,
StreamType: streamType, StreamType: string(streamType),
}, grpc.WaitForReady(true)) }, grpc.WaitForReady(true))
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
return "", "", nil, nil return "", "", nil, nil

View file

@ -171,7 +171,7 @@ func (s *GrpcServer) GetPublisherId(ctx context.Context, request *GetPublisherId
return nil, status.Error(codes.NotFound, "no such session") return nil, status.Error(codes.NotFound, "no such session")
} }
publisher := clientSession.GetOrWaitForPublisher(ctx, request.StreamType) publisher := clientSession.GetOrWaitForPublisher(ctx, StreamType(request.StreamType))
if publisher, ok := publisher.(*mcuProxyPublisher); ok { if publisher, ok := publisher.(*mcuProxyPublisher); ok {
reply := &GetPublisherIdReply{ reply := &GetPublisherIdReply{
PublisherId: publisher.Id(), PublisherId: publisher.Id(),

34
hub.go
View file

@ -1445,6 +1445,16 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
// Maybe this is a message to be processed by the MCU. // Maybe this is a message to be processed by the MCU.
var data MessageClientMessageData var data MessageClientMessageData
if err := json.Unmarshal(*msg.Data, &data); err == nil { if err := json.Unmarshal(*msg.Data, &data); err == nil {
if err := data.CheckValid(); err != nil {
log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err)
if err, ok := err.(*Error); ok {
session.SendMessage(message.NewErrorServerMessage(err))
} else {
session.SendMessage(message.NewErrorServerMessage(InvalidFormat))
}
return
}
clientData = &data clientData = &data
switch clientData.Type { switch clientData.Type {
@ -1476,7 +1486,7 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
return return
} }
publisher := session.GetPublisher(streamTypeScreen) publisher := session.GetPublisher(StreamTypeScreen)
if publisher == nil { if publisher == nil {
return return
} }
@ -1547,6 +1557,16 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
if h.mcu != nil { if h.mcu != nil {
var data MessageClientMessageData var data MessageClientMessageData
if err := json.Unmarshal(*msg.Data, &data); err == nil { if err := json.Unmarshal(*msg.Data, &data); err == nil {
if err := data.CheckValid(); err != nil {
log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err)
if err, ok := err.(*Error); ok {
session.SendMessage(message.NewErrorServerMessage(err))
} else {
session.SendMessage(message.NewErrorServerMessage(InvalidFormat))
}
return
}
clientData = &data clientData = &data
} }
} }
@ -1586,7 +1606,7 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout) ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout)
defer cancel() defer cancel()
mc, err := recipient.GetOrCreateSubscriber(ctx, h.mcu, session.PublicId(), clientData.RoomType) mc, err := recipient.GetOrCreateSubscriber(ctx, h.mcu, session.PublicId(), StreamType(clientData.RoomType))
if err != nil { if err != nil {
log.Printf("Could not create MCU subscriber for session %s to send %+v to %s: %s", session.PublicId(), clientData, recipient.PublicId(), err) log.Printf("Could not create MCU subscriber for session %s to send %+v to %s: %s", session.PublicId(), clientData, recipient.PublicId(), err)
sendMcuClientNotFound(session, message) sendMcuClientNotFound(session, message)
@ -2145,13 +2165,13 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe
} }
clientType = "subscriber" clientType = "subscriber"
mc, err = session.GetOrCreateSubscriber(ctx, h.mcu, message.Recipient.SessionId, data.RoomType) mc, err = session.GetOrCreateSubscriber(ctx, h.mcu, message.Recipient.SessionId, StreamType(data.RoomType))
case "sendoffer": case "sendoffer":
// Will be sent directly. // Will be sent directly.
return return
case "offer": case "offer":
clientType = "publisher" clientType = "publisher"
mc, err = session.GetOrCreatePublisher(ctx, h.mcu, data.RoomType, data) mc, err = session.GetOrCreatePublisher(ctx, h.mcu, StreamType(data.RoomType), data)
if err, ok := err.(*PermissionError); ok { if err, ok := err.(*PermissionError); ok {
log.Printf("Session %s is not allowed to offer %s, ignoring (%s)", session.PublicId(), data.RoomType, err) log.Printf("Session %s is not allowed to offer %s, ignoring (%s)", session.PublicId(), data.RoomType, err)
sendNotAllowed(session, client_message, "Not allowed to publish.") sendNotAllowed(session, client_message, "Not allowed to publish.")
@ -2169,7 +2189,7 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe
} }
clientType = "subscriber" clientType = "subscriber"
mc = session.GetSubscriber(message.Recipient.SessionId, data.RoomType) mc = session.GetSubscriber(message.Recipient.SessionId, StreamType(data.RoomType))
default: default:
if session.PublicId() == message.Recipient.SessionId { if session.PublicId() == message.Recipient.SessionId {
if err := session.IsAllowedToSend(data); err != nil { if err := session.IsAllowedToSend(data); err != nil {
@ -2179,10 +2199,10 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe
} }
clientType = "publisher" clientType = "publisher"
mc = session.GetPublisher(data.RoomType) mc = session.GetPublisher(StreamType(data.RoomType))
} else { } else {
clientType = "subscriber" clientType = "subscriber"
mc = session.GetSubscriber(message.Recipient.SessionId, data.RoomType) mc = session.GetSubscriber(message.Recipient.SessionId, StreamType(data.RoomType))
} }
} }
if err != nil { if err != nil {

View file

@ -75,14 +75,35 @@ type Mcu interface {
GetStats() interface{} GetStats() interface{}
NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error)
NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error)
}
type StreamType string
const (
StreamTypeAudio StreamType = "audio"
StreamTypeVideo StreamType = "video"
StreamTypeScreen StreamType = "screen"
)
func IsValidStreamType(s string) bool {
switch s {
case string(StreamTypeAudio):
fallthrough
case string(StreamTypeVideo):
fallthrough
case string(StreamTypeScreen):
return true
default:
return false
}
} }
type McuClient interface { type McuClient interface {
Id() string Id() string
Sid() string Sid() string
StreamType() string StreamType() StreamType
Close(ctx context.Context) Close(ctx context.Context)

View file

@ -50,18 +50,19 @@ const (
defaultMaxStreamBitrate = 1024 * 1024 defaultMaxStreamBitrate = 1024 * 1024
defaultMaxScreenBitrate = 2048 * 1024 defaultMaxScreenBitrate = 2048 * 1024
streamTypeVideo = "video"
streamTypeScreen = "screen"
) )
var ( var (
streamTypeUserIds = map[string]uint64{ streamTypeUserIds = map[StreamType]uint64{
streamTypeVideo: videoPublisherUserId, StreamTypeVideo: videoPublisherUserId,
streamTypeScreen: screenPublisherUserId, StreamTypeScreen: screenPublisherUserId,
} }
) )
func getStreamId(publisherId string, streamType StreamType) string {
return fmt.Sprintf("%s|%s", publisherId, streamType)
}
func getPluginValue(data janus.PluginData, pluginName string, key string) interface{} { func getPluginValue(data janus.PluginData, pluginName string, key string) interface{} {
if data.Plugin != pluginName { if data.Plugin != pluginName {
return nil return nil
@ -436,7 +437,7 @@ type mcuJanusClient struct {
session uint64 session uint64
roomId uint64 roomId uint64
sid string sid string
streamType string streamType StreamType
handle *JanusHandle handle *JanusHandle
handleId uint64 handleId uint64
@ -459,7 +460,7 @@ func (c *mcuJanusClient) Sid() string {
return c.sid return c.sid
} }
func (c *mcuJanusClient) StreamType() string { func (c *mcuJanusClient) StreamType() StreamType {
return c.streamType return c.streamType
} }
@ -609,7 +610,7 @@ func (c *mcuJanusClient) selectStream(ctx context.Context, stream *streamSelecti
type publisherStatsCounter struct { type publisherStatsCounter struct {
mu sync.Mutex mu sync.Mutex
streamTypes map[string]bool streamTypes map[StreamType]bool
subscribers map[string]bool subscribers map[string]bool
} }
@ -619,14 +620,14 @@ func (c *publisherStatsCounter) Reset() {
count := len(c.subscribers) count := len(c.subscribers)
for streamType := range c.streamTypes { for streamType := range c.streamTypes {
statsMcuPublisherStreamTypesCurrent.WithLabelValues(streamType).Dec() statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec()
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(streamType).Sub(float64(count)) statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(count))
} }
c.streamTypes = nil c.streamTypes = nil
c.subscribers = nil c.subscribers = nil
} }
func (c *publisherStatsCounter) EnableStream(streamType string, enable bool) { func (c *publisherStatsCounter) EnableStream(streamType StreamType, enable bool) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -636,15 +637,15 @@ func (c *publisherStatsCounter) EnableStream(streamType string, enable bool) {
if enable { if enable {
if c.streamTypes == nil { if c.streamTypes == nil {
c.streamTypes = make(map[string]bool) c.streamTypes = make(map[StreamType]bool)
} }
c.streamTypes[streamType] = true c.streamTypes[streamType] = true
statsMcuPublisherStreamTypesCurrent.WithLabelValues(streamType).Inc() statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Inc()
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(streamType).Add(float64(len(c.subscribers))) statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Add(float64(len(c.subscribers)))
} else { } else {
delete(c.streamTypes, streamType) delete(c.streamTypes, streamType)
statsMcuPublisherStreamTypesCurrent.WithLabelValues(streamType).Dec() statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec()
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(streamType).Sub(float64(len(c.subscribers))) statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(len(c.subscribers)))
} }
} }
@ -661,7 +662,7 @@ func (c *publisherStatsCounter) AddSubscriber(id string) {
} }
c.subscribers[id] = true c.subscribers[id] = true
for streamType := range c.streamTypes { for streamType := range c.streamTypes {
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(streamType).Inc() statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Inc()
} }
} }
@ -675,7 +676,7 @@ func (c *publisherStatsCounter) RemoveSubscriber(id string) {
delete(c.subscribers, id) delete(c.subscribers, id)
for streamType := range c.streamTypes { for streamType := range c.streamTypes {
statsMcuSubscriberStreamTypesCurrent.WithLabelValues(streamType).Dec() statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Dec()
} }
} }
@ -688,20 +689,20 @@ type mcuJanusPublisher struct {
stats publisherStatsCounter stats publisherStatsCounter
} }
func (m *mcuJanus) SubscriberConnected(id string, publisher string, streamType string) { func (m *mcuJanus) SubscriberConnected(id string, publisher string, streamType StreamType) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if p, found := m.publishers[publisher+"|"+streamType]; found { if p, found := m.publishers[getStreamId(publisher, streamType)]; found {
p.stats.AddSubscriber(id) p.stats.AddSubscriber(id)
} }
} }
func (m *mcuJanus) SubscriberDisconnected(id string, publisher string, streamType string) { func (m *mcuJanus) SubscriberDisconnected(id string, publisher string, streamType StreamType) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if p, found := m.publishers[publisher+"|"+streamType]; found { if p, found := m.publishers[getStreamId(publisher, streamType)]; found {
p.stats.RemoveSubscriber(id) p.stats.RemoveSubscriber(id)
} }
} }
@ -714,7 +715,7 @@ func min(a, b int) int {
return b return b
} }
func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, streamType string, bitrate int) (*JanusHandle, uint64, uint64, error) { func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, streamType StreamType, bitrate int) (*JanusHandle, uint64, uint64, error) {
session := m.session session := m.session
if session == nil { if session == nil {
return nil, 0, 0, ErrNotConnected return nil, 0, 0, ErrNotConnected
@ -727,7 +728,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st
log.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id) log.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id)
create_msg := map[string]interface{}{ create_msg := map[string]interface{}{
"request": "create", "request": "create",
"description": id + "|" + streamType, "description": getStreamId(id, streamType),
// We publish every stream in its own Janus room. // We publish every stream in its own Janus room.
"publishers": 1, "publishers": 1,
// Do not use the video-orientation RTP extension as it breaks video // Do not use the video-orientation RTP extension as it breaks video
@ -735,7 +736,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st
"videoorient_ext": false, "videoorient_ext": false,
} }
var maxBitrate int var maxBitrate int
if streamType == streamTypeScreen { if streamType == StreamTypeScreen {
maxBitrate = m.maxScreenBitrate maxBitrate = m.maxScreenBitrate
} else { } else {
maxBitrate = m.maxStreamBitrate maxBitrate = m.maxStreamBitrate
@ -782,7 +783,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st
return handle, response.Session, roomId, nil return handle, response.Session, roomId, nil
} }
func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) {
if _, found := streamTypeUserIds[streamType]; !found { if _, found := streamTypeUserIds[streamType]; !found {
return nil, fmt.Errorf("Unsupported stream type %s", streamType) return nil, fmt.Errorf("Unsupported stream type %s", streamType)
} }
@ -823,11 +824,11 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id st
log.Printf("Publisher %s is using handle %d", client.id, client.handleId) log.Printf("Publisher %s is using handle %d", client.id, client.handleId)
go client.run(handle, client.closeChan) go client.run(handle, client.closeChan)
m.mu.Lock() m.mu.Lock()
m.publishers[id+"|"+streamType] = client m.publishers[getStreamId(id, streamType)] = client
m.publisherCreated.Notify(id + "|" + streamType) m.publisherCreated.Notify(getStreamId(id, streamType))
m.mu.Unlock() m.mu.Unlock()
statsPublishersCurrent.WithLabelValues(streamType).Inc() statsPublishersCurrent.WithLabelValues(string(streamType)).Inc()
statsPublishersTotal.WithLabelValues(streamType).Inc() statsPublishersTotal.WithLabelValues(string(streamType)).Inc()
return client, nil return client, nil
} }
@ -860,7 +861,7 @@ func (p *mcuJanusPublisher) handleDetached(event *janus.DetachedMsg) {
func (p *mcuJanusPublisher) handleConnected(event *janus.WebRTCUpMsg) { func (p *mcuJanusPublisher) handleConnected(event *janus.WebRTCUpMsg) {
log.Printf("Publisher %d received connected", p.handleId) log.Printf("Publisher %d received connected", p.handleId)
p.mcu.publisherConnected.Notify(p.id + "|" + p.streamType) p.mcu.publisherConnected.Notify(getStreamId(p.id, p.streamType))
} }
func (p *mcuJanusPublisher) handleSlowLink(event *janus.SlowLinkMsg) { func (p *mcuJanusPublisher) handleSlowLink(event *janus.SlowLinkMsg) {
@ -872,8 +873,8 @@ func (p *mcuJanusPublisher) handleSlowLink(event *janus.SlowLinkMsg) {
} }
func (p *mcuJanusPublisher) handleMedia(event *janus.MediaMsg) { func (p *mcuJanusPublisher) handleMedia(event *janus.MediaMsg) {
mediaType := event.Type mediaType := StreamType(event.Type)
if mediaType == "video" && p.streamType == "screen" { if mediaType == StreamTypeVideo && p.streamType == StreamTypeScreen {
// We want to differentiate between audio, video and screensharing // We want to differentiate between audio, video and screensharing
mediaType = p.streamType mediaType = p.streamType
} }
@ -920,7 +921,7 @@ func (p *mcuJanusPublisher) Close(ctx context.Context) {
log.Printf("Room %d destroyed", p.roomId) log.Printf("Room %d destroyed", p.roomId)
} }
p.mcu.mu.Lock() p.mcu.mu.Lock()
delete(p.mcu.publishers, p.id+"|"+p.streamType) delete(p.mcu.publishers, getStreamId(p.id, p.streamType))
p.mcu.mu.Unlock() p.mcu.mu.Unlock()
p.roomId = 0 p.roomId = 0
notify = true notify = true
@ -931,7 +932,7 @@ func (p *mcuJanusPublisher) Close(ctx context.Context) {
p.stats.Reset() p.stats.Reset()
if notify { if notify {
statsPublishersCurrent.WithLabelValues(p.streamType).Dec() statsPublishersCurrent.WithLabelValues(string(p.streamType)).Dec()
p.mcu.unregisterClient(p) p.mcu.unregisterClient(p)
p.listener.PublisherClosed(p) p.listener.PublisherClosed(p)
} }
@ -975,9 +976,9 @@ type mcuJanusSubscriber struct {
publisher string publisher string
} }
func (m *mcuJanus) getPublisher(ctx context.Context, publisher string, streamType string) (*mcuJanusPublisher, error) { func (m *mcuJanus) getPublisher(ctx context.Context, publisher string, streamType StreamType) (*mcuJanusPublisher, error) {
// Do the direct check immediately as this should be the normal case. // Do the direct check immediately as this should be the normal case.
key := publisher + "|" + streamType key := getStreamId(publisher, streamType)
m.mu.Lock() m.mu.Lock()
if result, found := m.publishers[key]; found { if result, found := m.publishers[key]; found {
m.mu.Unlock() m.mu.Unlock()
@ -1002,7 +1003,7 @@ func (m *mcuJanus) getPublisher(ctx context.Context, publisher string, streamTyp
} }
} }
func (m *mcuJanus) getOrCreateSubscriberHandle(ctx context.Context, publisher string, streamType string) (*JanusHandle, *mcuJanusPublisher, error) { func (m *mcuJanus) getOrCreateSubscriberHandle(ctx context.Context, publisher string, streamType StreamType) (*JanusHandle, *mcuJanusPublisher, error) {
var pub *mcuJanusPublisher var pub *mcuJanusPublisher
var err error var err error
if pub, err = m.getPublisher(ctx, publisher, streamType); err != nil { if pub, err = m.getPublisher(ctx, publisher, streamType); err != nil {
@ -1023,7 +1024,7 @@ func (m *mcuJanus) getOrCreateSubscriberHandle(ctx context.Context, publisher st
return handle, pub, nil return handle, pub, nil
} }
func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) {
if _, found := streamTypeUserIds[streamType]; !found { if _, found := streamTypeUserIds[streamType]; !found {
return nil, fmt.Errorf("Unsupported stream type %s", streamType) return nil, fmt.Errorf("Unsupported stream type %s", streamType)
} }
@ -1058,8 +1059,8 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ
client.mcuJanusClient.handleMedia = client.handleMedia client.mcuJanusClient.handleMedia = client.handleMedia
m.registerClient(client) m.registerClient(client)
go client.run(handle, client.closeChan) go client.run(handle, client.closeChan)
statsSubscribersCurrent.WithLabelValues(streamType).Inc() statsSubscribersCurrent.WithLabelValues(string(streamType)).Inc()
statsSubscribersTotal.WithLabelValues(streamType).Inc() statsSubscribersTotal.WithLabelValues(string(streamType)).Inc()
return client, nil return client, nil
} }
@ -1144,7 +1145,7 @@ func (p *mcuJanusSubscriber) Close(ctx context.Context) {
if closed { if closed {
p.mcu.SubscriberDisconnected(p.Id(), p.publisher, p.streamType) p.mcu.SubscriberDisconnected(p.Id(), p.publisher, p.streamType)
statsSubscribersCurrent.WithLabelValues(p.streamType).Dec() statsSubscribersCurrent.WithLabelValues(string(p.streamType)).Dec()
} }
p.mcu.unregisterClient(p) p.mcu.unregisterClient(p)
p.listener.SubscriberClosed(p) p.listener.SubscriberClosed(p)
@ -1158,7 +1159,7 @@ func (p *mcuJanusSubscriber) joinRoom(ctx context.Context, stream *streamSelecti
return return
} }
waiter := p.mcu.publisherConnected.NewWaiter(p.publisher + "|" + p.streamType) waiter := p.mcu.publisherConnected.NewWaiter(getStreamId(p.publisher, p.streamType))
defer p.mcu.publisherConnected.Release(waiter) defer p.mcu.publisherConnected.Release(waiter)
loggedNotPublishingYet := false loggedNotPublishingYet := false
@ -1223,7 +1224,7 @@ retry:
if !loggedNotPublishingYet { if !loggedNotPublishingYet {
loggedNotPublishingYet = true loggedNotPublishingYet = true
statsWaitingForPublisherTotal.WithLabelValues(p.streamType).Inc() statsWaitingForPublisherTotal.WithLabelValues(string(p.streamType)).Inc()
} }
if err := waiter.Wait(ctx); err != nil { if err := waiter.Wait(ctx); err != nil {

View file

@ -76,7 +76,7 @@ type McuProxy interface {
type mcuProxyPubSubCommon struct { type mcuProxyPubSubCommon struct {
sid string sid string
streamType string streamType StreamType
proxyId string proxyId string
conn *mcuProxyConnection conn *mcuProxyConnection
listener McuListener listener McuListener
@ -90,7 +90,7 @@ func (c *mcuProxyPubSubCommon) Sid() string {
return c.sid return c.sid
} }
func (c *mcuProxyPubSubCommon) StreamType() string { func (c *mcuProxyPubSubCommon) StreamType() StreamType {
return c.streamType return c.streamType
} }
@ -132,7 +132,7 @@ type mcuProxyPublisher struct {
mediaTypes MediaType mediaTypes MediaType
} }
func newMcuProxyPublisher(id string, sid string, streamType string, mediaTypes MediaType, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher { func newMcuProxyPublisher(id string, sid string, streamType StreamType, mediaTypes MediaType, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher {
return &mcuProxyPublisher{ return &mcuProxyPublisher{
mcuProxyPubSubCommon: mcuProxyPubSubCommon{ mcuProxyPubSubCommon: mcuProxyPubSubCommon{
sid: sid, sid: sid,
@ -217,7 +217,7 @@ type mcuProxySubscriber struct {
publisherId string publisherId string
} }
func newMcuProxySubscriber(publisherId string, sid string, streamType string, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxySubscriber { func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxySubscriber {
return &mcuProxySubscriber{ return &mcuProxySubscriber{
mcuProxyPubSubCommon: mcuProxyPubSubCommon{ mcuProxyPubSubCommon: mcuProxyPubSubCommon{
sid: sid, sid: sid,
@ -719,9 +719,9 @@ func (c *mcuProxyConnection) removePublisher(publisher *mcuProxyPublisher) {
if _, found := c.publishers[publisher.proxyId]; found { if _, found := c.publishers[publisher.proxyId]; found {
delete(c.publishers, publisher.proxyId) delete(c.publishers, publisher.proxyId)
statsPublishersCurrent.WithLabelValues(publisher.StreamType()).Dec() statsPublishersCurrent.WithLabelValues(string(publisher.StreamType())).Dec()
} }
delete(c.publisherIds, publisher.id+"|"+publisher.StreamType()) delete(c.publisherIds, getStreamId(publisher.id, publisher.StreamType()))
if len(c.publishers) == 0 && (c.closeScheduled.Load() || c.IsTemporary()) { if len(c.publishers) == 0 && (c.closeScheduled.Load() || c.IsTemporary()) {
go c.closeIfEmpty() go c.closeIfEmpty()
@ -751,7 +751,7 @@ func (c *mcuProxyConnection) removeSubscriber(subscriber *mcuProxySubscriber) {
if _, found := c.subscribers[subscriber.proxyId]; found { if _, found := c.subscribers[subscriber.proxyId]; found {
delete(c.subscribers, subscriber.proxyId) delete(c.subscribers, subscriber.proxyId)
statsSubscribersCurrent.WithLabelValues(subscriber.StreamType()).Dec() statsSubscribersCurrent.WithLabelValues(string(subscriber.StreamType())).Dec()
} }
if len(c.subscribers) == 0 && (c.closeScheduled.Load() || c.IsTemporary()) { if len(c.subscribers) == 0 && (c.closeScheduled.Load() || c.IsTemporary()) {
@ -1032,7 +1032,7 @@ func (c *mcuProxyConnection) performSyncRequest(ctx context.Context, msg *ProxyC
} }
} }
func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType) (McuPublisher, error) { func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType) (McuPublisher, error) {
msg := &ProxyClientMessage{ msg := &ProxyClientMessage{
Type: "command", Type: "command",
Command: &CommandProxyClientMessage{ Command: &CommandProxyClientMessage{
@ -1057,14 +1057,14 @@ func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListe
publisher := newMcuProxyPublisher(id, sid, streamType, mediaTypes, proxyId, c, listener) publisher := newMcuProxyPublisher(id, sid, streamType, mediaTypes, proxyId, c, listener)
c.publishersLock.Lock() c.publishersLock.Lock()
c.publishers[proxyId] = publisher c.publishers[proxyId] = publisher
c.publisherIds[id+"|"+streamType] = proxyId c.publisherIds[getStreamId(id, streamType)] = proxyId
c.publishersLock.Unlock() c.publishersLock.Unlock()
statsPublishersCurrent.WithLabelValues(streamType).Inc() statsPublishersCurrent.WithLabelValues(string(streamType)).Inc()
statsPublishersTotal.WithLabelValues(streamType).Inc() statsPublishersTotal.WithLabelValues(string(streamType)).Inc()
return publisher, nil return publisher, nil
} }
func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuListener, publisherId string, publisherSessionId string, streamType string) (McuSubscriber, error) { func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuListener, publisherId string, publisherSessionId string, streamType StreamType) (McuSubscriber, error) {
msg := &ProxyClientMessage{ msg := &ProxyClientMessage{
Type: "command", Type: "command",
Command: &CommandProxyClientMessage{ Command: &CommandProxyClientMessage{
@ -1088,8 +1088,8 @@ func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuList
c.subscribersLock.Lock() c.subscribersLock.Lock()
c.subscribers[proxyId] = subscriber c.subscribers[proxyId] = subscriber
c.subscribersLock.Unlock() c.subscribersLock.Unlock()
statsSubscribersCurrent.WithLabelValues(streamType).Inc() statsSubscribersCurrent.WithLabelValues(string(streamType)).Inc()
statsSubscribersTotal.WithLabelValues(streamType).Inc() statsSubscribersTotal.WithLabelValues(string(streamType)).Inc()
return subscriber, nil return subscriber, nil
} }
@ -1555,10 +1555,10 @@ func (m *mcuProxy) removePublisher(publisher *mcuProxyPublisher) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
delete(m.publishers, publisher.id+"|"+publisher.StreamType()) delete(m.publishers, getStreamId(publisher.id, publisher.StreamType()))
} }
func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) {
connections := m.getSortedConnections(initiator) connections := m.getSortedConnections(initiator)
for _, conn := range connections { for _, conn := range connections {
if conn.IsShutdownScheduled() || conn.IsTemporary() { if conn.IsShutdownScheduled() || conn.IsTemporary() {
@ -1569,7 +1569,7 @@ func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id st
defer cancel() defer cancel()
var maxBitrate int var maxBitrate int
if streamType == streamTypeScreen { if streamType == StreamTypeScreen {
maxBitrate = m.maxScreenBitrate maxBitrate = m.maxScreenBitrate
} else { } else {
maxBitrate = m.maxStreamBitrate maxBitrate = m.maxStreamBitrate
@ -1586,28 +1586,28 @@ func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id st
} }
m.mu.Lock() m.mu.Lock()
m.publishers[id+"|"+streamType] = conn m.publishers[getStreamId(id, streamType)] = conn
m.mu.Unlock() m.mu.Unlock()
m.publisherWaiters.Wakeup() m.publisherWaiters.Wakeup()
return publisher, nil return publisher, nil
} }
statsProxyNobackendAvailableTotal.WithLabelValues(streamType).Inc() statsProxyNobackendAvailableTotal.WithLabelValues(string(streamType)).Inc()
return nil, fmt.Errorf("No MCU connection available") return nil, fmt.Errorf("No MCU connection available")
} }
func (m *mcuProxy) getPublisherConnection(publisher string, streamType string) *mcuProxyConnection { func (m *mcuProxy) getPublisherConnection(publisher string, streamType StreamType) *mcuProxyConnection {
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()
return m.publishers[publisher+"|"+streamType] return m.publishers[getStreamId(publisher, streamType)]
} }
func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher string, streamType string) *mcuProxyConnection { func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher string, streamType StreamType) *mcuProxyConnection {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
conn := m.publishers[publisher+"|"+streamType] conn := m.publishers[getStreamId(publisher, streamType)]
if conn != nil { if conn != nil {
// Publisher was created while waiting for lock. // Publisher was created while waiting for lock.
return conn return conn
@ -1617,13 +1617,13 @@ func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher str
id := m.publisherWaiters.Add(ch) id := m.publisherWaiters.Add(ch)
defer m.publisherWaiters.Remove(id) defer m.publisherWaiters.Remove(id)
statsWaitingForPublisherTotal.WithLabelValues(streamType).Inc() statsWaitingForPublisherTotal.WithLabelValues(string(streamType)).Inc()
for { for {
m.mu.Unlock() m.mu.Unlock()
select { select {
case <-ch: case <-ch:
m.mu.Lock() m.mu.Lock()
conn = m.publishers[publisher+"|"+streamType] conn = m.publishers[getStreamId(publisher, streamType)]
if conn != nil { if conn != nil {
return conn return conn
} }
@ -1634,11 +1634,11 @@ func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher str
} }
} }
func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) {
if conn := m.getPublisherConnection(publisher, streamType); conn != nil { if conn := m.getPublisherConnection(publisher, streamType); conn != nil {
// Fast common path: publisher is available locally. // Fast common path: publisher is available locally.
conn.publishersLock.Lock() conn.publishersLock.Lock()
id, found := conn.publisherIds[publisher+"|"+streamType] id, found := conn.publisherIds[getStreamId(publisher, streamType)]
conn.publishersLock.Unlock() conn.publishersLock.Unlock()
if !found { if !found {
return nil, fmt.Errorf("Unknown publisher %s", publisher) return nil, fmt.Errorf("Unknown publisher %s", publisher)
@ -1658,7 +1658,7 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ
cancel() // Cancel pending RPC calls. cancel() // Cancel pending RPC calls.
conn.publishersLock.Lock() conn.publishersLock.Lock()
id, found := conn.publisherIds[publisher+"|"+streamType] id, found := conn.publisherIds[getStreamId(publisher, streamType)]
conn.publishersLock.Unlock() conn.publishersLock.Unlock()
if !found { if !found {
log.Printf("Unknown id for local %s publisher %s", streamType, publisher) log.Printf("Unknown id for local %s publisher %s", streamType, publisher)

View file

@ -69,9 +69,9 @@ func (m *TestMCU) GetStats() interface{} {
return nil return nil
} }
func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) {
var maxBitrate int var maxBitrate int
if streamType == streamTypeScreen { if streamType == StreamTypeScreen {
maxBitrate = TestMaxBitrateScreen maxBitrate = TestMaxBitrateScreen
} else { } else {
maxBitrate = TestMaxBitrateVideo maxBitrate = TestMaxBitrateVideo
@ -117,7 +117,7 @@ func (m *TestMCU) GetPublisher(id string) *TestMCUPublisher {
return m.publishers[id] return m.publishers[id]
} }
func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@ -143,7 +143,7 @@ type TestMCUClient struct {
id string id string
sid string sid string
streamType string streamType StreamType
} }
func (c *TestMCUClient) Id() string { func (c *TestMCUClient) Id() string {
@ -154,7 +154,7 @@ func (c *TestMCUClient) Sid() string {
return c.sid return c.sid
} }
func (c *TestMCUClient) StreamType() string { func (c *TestMCUClient) StreamType() StreamType {
return c.streamType return c.streamType
} }

View file

@ -657,8 +657,8 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
}, },
} }
session.sendMessage(response) session.sendMessage(response)
statsPublishersCurrent.WithLabelValues(cmd.StreamType).Inc() statsPublishersCurrent.WithLabelValues(string(cmd.StreamType)).Inc()
statsPublishersTotal.WithLabelValues(cmd.StreamType).Inc() statsPublishersTotal.WithLabelValues(string(cmd.StreamType)).Inc()
case "create-subscriber": case "create-subscriber":
id := uuid.New().String() id := uuid.New().String()
publisherId := cmd.PublisherId publisherId := cmd.PublisherId
@ -686,8 +686,8 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
}, },
} }
session.sendMessage(response) session.sendMessage(response)
statsSubscribersCurrent.WithLabelValues(cmd.StreamType).Inc() statsSubscribersCurrent.WithLabelValues(string(cmd.StreamType)).Inc()
statsSubscribersTotal.WithLabelValues(cmd.StreamType).Inc() statsSubscribersTotal.WithLabelValues(string(cmd.StreamType)).Inc()
case "delete-publisher": case "delete-publisher":
client := s.GetClient(cmd.ClientId) client := s.GetClient(cmd.ClientId)
if client == nil { if client == nil {
@ -707,7 +707,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
} }
if s.DeleteClient(cmd.ClientId, client) { if s.DeleteClient(cmd.ClientId, client) {
statsPublishersCurrent.WithLabelValues(client.StreamType()).Dec() statsPublishersCurrent.WithLabelValues(string(client.StreamType())).Dec()
} }
go func() { go func() {
@ -742,7 +742,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
} }
if s.DeleteClient(cmd.ClientId, client) { if s.DeleteClient(cmd.ClientId, client) {
statsSubscribersCurrent.WithLabelValues(client.StreamType()).Dec() statsSubscribersCurrent.WithLabelValues(string(client.StreamType())).Dec()
} }
go func() { go func() {

View file

@ -212,7 +212,7 @@ func (s *ProxySession) SubscriberSidUpdated(subscriber signaling.McuSubscriber)
func (s *ProxySession) PublisherClosed(publisher signaling.McuPublisher) { func (s *ProxySession) PublisherClosed(publisher signaling.McuPublisher) {
if id := s.DeletePublisher(publisher); id != "" { if id := s.DeletePublisher(publisher); id != "" {
if s.proxy.DeleteClient(id, publisher) { if s.proxy.DeleteClient(id, publisher) {
statsPublishersCurrent.WithLabelValues(publisher.StreamType()).Dec() statsPublishersCurrent.WithLabelValues(string(publisher.StreamType())).Dec()
} }
msg := &signaling.ProxyServerMessage{ msg := &signaling.ProxyServerMessage{
@ -229,7 +229,7 @@ func (s *ProxySession) PublisherClosed(publisher signaling.McuPublisher) {
func (s *ProxySession) SubscriberClosed(subscriber signaling.McuSubscriber) { func (s *ProxySession) SubscriberClosed(subscriber signaling.McuSubscriber) {
if id := s.DeleteSubscriber(subscriber); id != "" { if id := s.DeleteSubscriber(subscriber); id != "" {
if s.proxy.DeleteClient(id, subscriber) { if s.proxy.DeleteClient(id, subscriber) {
statsSubscribersCurrent.WithLabelValues(subscriber.StreamType()).Dec() statsSubscribersCurrent.WithLabelValues(string(subscriber.StreamType())).Dec()
} }
msg := &signaling.ProxyServerMessage{ msg := &signaling.ProxyServerMessage{
@ -294,7 +294,7 @@ func (s *ProxySession) clearPublishers() {
go func(publishers map[string]signaling.McuPublisher) { go func(publishers map[string]signaling.McuPublisher) {
for id, publisher := range publishers { for id, publisher := range publishers {
if s.proxy.DeleteClient(id, publisher) { if s.proxy.DeleteClient(id, publisher) {
statsPublishersCurrent.WithLabelValues(publisher.StreamType()).Dec() statsPublishersCurrent.WithLabelValues(string(publisher.StreamType())).Dec()
} }
publisher.Close(context.Background()) publisher.Close(context.Background())
} }
@ -310,7 +310,7 @@ func (s *ProxySession) clearSubscribers() {
go func(subscribers map[string]signaling.McuSubscriber) { go func(subscribers map[string]signaling.McuSubscriber) {
for id, subscriber := range subscribers { for id, subscriber := range subscribers {
if s.proxy.DeleteClient(id, subscriber) { if s.proxy.DeleteClient(id, subscriber) {
statsSubscribersCurrent.WithLabelValues(subscriber.StreamType()).Dec() statsSubscribersCurrent.WithLabelValues(string(subscriber.StreamType())).Dec()
} }
subscriber.Close(context.Background()) subscriber.Close(context.Background())
} }