diff --git a/mcu_janus.go b/mcu_janus.go index 1164736..06f6a9b 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -23,6 +23,7 @@ package signaling import ( "context" + "database/sql" "encoding/json" "fmt" "log" @@ -580,14 +581,14 @@ func (c *mcuJanusClient) handleTrickle(event *TrickleMsg) { } } -func (c *mcuJanusClient) selectStream(ctx context.Context, substream int, temporal int, callback func(error, map[string]interface{})) { +func (c *mcuJanusClient) selectStream(ctx context.Context, stream *streamSelection, callback func(error, map[string]interface{})) { handle := c.handle if handle == nil { callback(ErrNotConnected, nil) return } - if substream < 0 && temporal < 0 { + if stream == nil || !stream.HasValues() { callback(nil, nil) return } @@ -595,11 +596,8 @@ func (c *mcuJanusClient) selectStream(ctx context.Context, substream int, tempor configure_msg := map[string]interface{}{ "request": "configure", } - if substream >= 0 { - configure_msg["substream"] = substream - } - if temporal >= 0 { - configure_msg["temporal"] = temporal + if stream != nil { + stream.AddToMessage(configure_msg) } _, err := handle.Message(ctx, configure_msg, nil) if err != nil { @@ -1155,7 +1153,7 @@ func (p *mcuJanusSubscriber) Close(ctx context.Context) { p.mcuJanusClient.Close(ctx) } -func (p *mcuJanusSubscriber) joinRoom(ctx context.Context, callback func(error, map[string]interface{})) { +func (p *mcuJanusSubscriber) joinRoom(ctx context.Context, stream *streamSelection, callback func(error, map[string]interface{})) { handle := p.handle if handle == nil { callback(ErrNotConnected, nil) @@ -1173,6 +1171,9 @@ retry: "room": p.roomId, "feed": streamTypeUserIds[p.streamType], } + if stream != nil { + stream.AddToMessage(join_msg) + } join_response, err := handle.Message(ctx, join_msg, nil) if err != nil { callback(err, nil) @@ -1245,7 +1246,7 @@ retry: callback(nil, join_response.Jsep) } -func (p *mcuJanusSubscriber) update(ctx context.Context, callback func(error, map[string]interface{})) { +func (p *mcuJanusSubscriber) update(ctx context.Context, stream *streamSelection, callback func(error, map[string]interface{})) { handle := p.handle if handle == nil { callback(ErrNotConnected, nil) @@ -1256,6 +1257,9 @@ func (p *mcuJanusSubscriber) update(ctx context.Context, callback func(error, ma "request": "configure", "update": true, } + if stream != nil { + stream.AddToMessage(configure_msg) + } configure_response, err := handle.Message(ctx, configure_msg, nil) if err != nil { callback(err, nil) @@ -1265,6 +1269,89 @@ func (p *mcuJanusSubscriber) update(ctx context.Context, callback func(error, ma callback(nil, configure_response.Jsep) } +type streamSelection struct { + substream sql.NullInt16 + temporal sql.NullInt16 + audio sql.NullBool + video sql.NullBool +} + +func (s *streamSelection) HasValues() bool { + return s.substream.Valid || s.temporal.Valid || s.audio.Valid || s.video.Valid +} + +func (s *streamSelection) AddToMessage(message map[string]interface{}) { + if s.substream.Valid { + message["substream"] = s.substream.Int16 + } + if s.temporal.Valid { + message["temporal"] = s.temporal.Int16 + } + if s.audio.Valid { + message["audio"] = s.audio.Bool + } + if s.video.Valid { + message["video"] = s.video.Bool + } +} + +func parseStreamSelection(payload map[string]interface{}) (*streamSelection, error) { + var stream streamSelection + if value, found := payload["substream"]; found { + switch value := value.(type) { + case int: + stream.substream.Valid = true + stream.substream.Int16 = int16(value) + case float32: + stream.substream.Valid = true + stream.substream.Int16 = int16(value) + case float64: + stream.substream.Valid = true + stream.substream.Int16 = int16(value) + default: + return nil, fmt.Errorf("Unsupported substream value: %v", value) + } + } + + if value, found := payload["temporal"]; found { + switch value := value.(type) { + case int: + stream.temporal.Valid = true + stream.temporal.Int16 = int16(value) + case float32: + stream.temporal.Valid = true + stream.temporal.Int16 = int16(value) + case float64: + stream.temporal.Valid = true + stream.temporal.Int16 = int16(value) + default: + return nil, fmt.Errorf("Unsupported temporal value: %v", value) + } + } + + if value, found := payload["audio"]; found { + switch value := value.(type) { + case bool: + stream.audio.Valid = true + stream.audio.Bool = value + default: + return nil, fmt.Errorf("Unsupported audio value: %v", value) + } + } + + if value, found := payload["video"]; found { + switch value := value.(type) { + case bool: + stream.video.Valid = true + stream.video.Bool = value + default: + return nil, fmt.Errorf("Unsupported video value: %v", value) + } + } + + return &stream, nil +} + func (p *mcuJanusSubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { statsMcuMessagesTotal.WithLabelValues(data.Type).Inc() jsep_msg := data.Payload @@ -1276,10 +1363,16 @@ func (p *mcuJanusSubscriber) SendMessage(ctx context.Context, message *MessageCl msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) defer cancel() + stream, err := parseStreamSelection(jsep_msg) + if err != nil { + go callback(err, nil) + return + } + if data.Sid == "" || data.Sid != p.Sid() { - p.joinRoom(msgctx, callback) + p.joinRoom(msgctx, stream, callback) } else { - p.update(msgctx, callback) + p.update(msgctx, stream, callback) } } case "answer": @@ -1307,35 +1400,13 @@ func (p *mcuJanusSubscriber) SendMessage(ctx context.Context, message *MessageCl case "endOfCandidates": // Ignore case "selectStream": - substream := -1 - if s, found := jsep_msg["substream"]; found { - switch s := s.(type) { - case int: - substream = s - case float32: - substream = int(s) - case float64: - substream = int(s) - default: - go callback(fmt.Errorf("Unsupported substream value: %v", s), nil) - return - } + stream, err := parseStreamSelection(jsep_msg) + if err != nil { + go callback(err, nil) + return } - temporal := -1 - if s, found := jsep_msg["temporal"]; found { - switch s := s.(type) { - case int: - temporal = s - case float32: - temporal = int(s) - case float64: - temporal = int(s) - default: - go callback(fmt.Errorf("Unsupported temporal value: %v", s), nil) - return - } - } - if substream == -1 && temporal == -1 { + + if stream == nil || !stream.HasValues() { // Nothing to do go callback(nil, nil) return @@ -1345,7 +1416,7 @@ func (p *mcuJanusSubscriber) SendMessage(ctx context.Context, message *MessageCl msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) defer cancel() - p.selectStream(msgctx, substream, temporal, callback) + p.selectStream(msgctx, stream, callback) } default: // Return error asynchronously