diff --git a/api_proxy.go b/api_proxy.go index ccf953a..83f3058 100644 --- a/api_proxy.go +++ b/api_proxy.go @@ -24,6 +24,7 @@ package signaling import ( "encoding/json" "fmt" + "net/url" "github.com/golang-jwt/jwt/v4" ) @@ -48,6 +49,14 @@ type ProxyClientMessage struct { Payload *PayloadProxyClientMessage `json:"payload,omitempty"` } +func (m *ProxyClientMessage) String() string { + data, err := json.Marshal(m) + if err != nil { + return fmt.Sprintf("Could not serialize %#v: %s", m, err) + } + return string(data) +} + func (m *ProxyClientMessage) CheckValid() error { switch m.Type { case "": @@ -115,6 +124,14 @@ type ProxyServerMessage struct { Event *EventProxyServerMessage `json:"event,omitempty"` } +func (r *ProxyServerMessage) String() string { + data, err := json.Marshal(r) + if err != nil { + return fmt.Sprintf("Could not serialize %#v: %s", r, err) + } + return string(data) +} + func (r *ProxyServerMessage) CloseAfterSend(session Session) bool { switch r.Type { case "bye": @@ -185,6 +202,14 @@ type CommandProxyClientMessage struct { ClientId string `json:"clientId,omitempty"` Bitrate int `json:"bitrate,omitempty"` MediaTypes MediaType `json:"mediatypes,omitempty"` + + RemoteUrl string `json:"remoteUrl,omitempty"` + remoteUrl *url.URL + RemoteToken string `json:"remoteToken,omitempty"` + + Hostname string `json:"hostname,omitempty"` + Port int `json:"port,omitempty"` + RtcpPort int `json:"rtcpPort,omitempty"` } func (m *CommandProxyClientMessage) CheckValid() error { @@ -202,6 +227,17 @@ func (m *CommandProxyClientMessage) CheckValid() error { if m.StreamType == "" { return fmt.Errorf("stream type missing") } + if m.RemoteUrl != "" { + if m.RemoteToken == "" { + return fmt.Errorf("remote token missing") + } + + remoteUrl, err := url.Parse(m.RemoteUrl) + if err != nil { + return fmt.Errorf("invalid remote url: %w", err) + } + m.remoteUrl = remoteUrl + } case "delete-publisher": fallthrough case "delete-subscriber": @@ -217,6 +253,8 @@ type CommandProxyServerMessage struct { Sid string `json:"sid,omitempty"` Bitrate int `json:"bitrate,omitempty"` + + Streams []PublisherStream `json:"streams,omitempty"` } // Type "payload" @@ -261,12 +299,41 @@ type PayloadProxyServerMessage struct { // Type "event" +type EventProxyServerBandwidth struct { + // Incoming is the bandwidth utilization for publishers in percent. + Incoming *float64 `json:"incoming,omitempty"` + // Outgoing is the bandwidth utilization for subscribers in percent. + Outgoing *float64 `json:"outgoing,omitempty"` +} + +func (b *EventProxyServerBandwidth) String() string { + if b.Incoming != nil && b.Outgoing != nil { + return fmt.Sprintf("bandwidth: incoming=%.3f%%, outgoing=%.3f%%", *b.Incoming, *b.Outgoing) + } else if b.Incoming != nil { + return fmt.Sprintf("bandwidth: incoming=%.3f%%, outgoing=unlimited", *b.Incoming) + } else if b.Outgoing != nil { + return fmt.Sprintf("bandwidth: incoming=unlimited, outgoing=%.3f%%", *b.Outgoing) + } else { + return "bandwidth: incoming=unlimited, outgoing=unlimited" + } +} + +func (b EventProxyServerBandwidth) AllowIncoming() bool { + return b.Incoming == nil || *b.Incoming < 100 +} + +func (b EventProxyServerBandwidth) AllowOutgoing() bool { + return b.Outgoing == nil || *b.Outgoing < 100 +} + type EventProxyServerMessage struct { Type string `json:"type"` ClientId string `json:"clientId,omitempty"` Load int64 `json:"load,omitempty"` Sid string `json:"sid,omitempty"` + + Bandwidth *EventProxyServerBandwidth `json:"bandwidth,omitempty"` } // Information on a proxy in the etcd cluster. diff --git a/clientsession.go b/clientsession.go index 72c22d1..d4e8c40 100644 --- a/clientsession.go +++ b/clientsession.go @@ -934,9 +934,10 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s subscriber, found := s.subscribers[getStreamId(id, streamType)] if !found { + client := s.getClientUnlocked() s.mu.Unlock() var err error - subscriber, err = mcu.NewSubscriber(ctx, s, id, streamType) + subscriber, err = mcu.NewSubscriber(ctx, s, id, streamType, client) s.mu.Lock() if err != nil { return nil, err diff --git a/clientsession_test.go b/clientsession_test.go index 43de0e4..6d3b9a4 100644 --- a/clientsession_test.go +++ b/clientsession_test.go @@ -131,10 +131,13 @@ func TestBandwidth_Client(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() @@ -148,9 +151,6 @@ func TestBandwidth_Client(t *testing.T) { t.Fatal(err) } - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - hello, err := client.RunUntilHello(ctx) if err != nil { t.Fatal(err) @@ -217,10 +217,13 @@ func TestBandwidth_Backend(t *testing.T) { backend.maxScreenBitrate = 1000 backend.maxStreamBitrate = 2000 + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() @@ -232,9 +235,6 @@ func TestBandwidth_Backend(t *testing.T) { StreamTypeScreen, } - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - for _, streamType := range streamTypes { t.Run(string(streamType), func(t *testing.T) { client := NewTestClient(t, server, hub) diff --git a/docker/README.md b/docker/README.md index c19e078..28c76a1 100644 --- a/docker/README.md +++ b/docker/README.md @@ -100,6 +100,11 @@ The running container can be configured through different environment variables: - `CONFIG`: Optional name of configuration file to use. - `HTTP_LISTEN`: Address of HTTP listener. - `COUNTRY`: Optional ISO 3166 country this proxy is located at. +- `EXTERNAL_HOSTNAME`: The external hostname for remote streams. Will try to autodetect if omitted. +- `TOKEN_ID`: Id of the token to use when connecting remote streams. +- `TOKEN_KEY`: Private key for the configured token id. +- `BANDWIDTH_INCOMING`: Optional incoming target bandwidth (in megabits per second). +- `BANDWIDTH_OUTGOING`: Optional outgoing target bandwidth (in megabits per second). - `JANUS_URL`: Url to Janus server. - `MAX_STREAM_BITRATE`: Optional maximum bitrate for audio/video streams. - `MAX_SCREEN_BITRATE`: Optional maximum bitrate for screensharing streams. diff --git a/docker/proxy/entrypoint.sh b/docker/proxy/entrypoint.sh index 31e37f0..6eba347 100755 --- a/docker/proxy/entrypoint.sh +++ b/docker/proxy/entrypoint.sh @@ -44,6 +44,21 @@ if [ ! -f "$CONFIG" ]; then sed -i "s|#country =.*|country = $COUNTRY|" "$CONFIG" fi + if [ -n "$EXTERNAL_HOSTNAME" ]; then + sed -i "s|#hostname =.*|hostname = $EXTERNAL_HOSTNAME|" "$CONFIG" + fi + if [ -n "$TOKEN_ID" ]; then + sed -i "s|#token_id =.*|token_id = $TOKEN_ID|" "$CONFIG" + fi + if [ -n "$TOKEN_KEY" ]; then + sed -i "s|#token_key =.*|token_key = $TOKEN_KEY|" "$CONFIG" + if [ -n "$BANDWIDTH_INCOMING" ]; then + sed -i "s|#incoming =.*|incoming = $BANDWIDTH_INCOMING|" "$CONFIG" + fi + if [ -n "$BANDWIDTH_OUTGOING" ]; then + sed -i "s|#outgoing =.*|outgoing = $BANDWIDTH_OUTGOING|" "$CONFIG" + fi + HAS_ETCD= if [ -n "$ETCD_ENDPOINTS" ]; then sed -i "s|#endpoints =.*|endpoints = $ETCD_ENDPOINTS|" "$CONFIG" diff --git a/grpc_server.go b/grpc_server.go index 236467d..0e1d30e 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -55,6 +55,14 @@ func init() { GrpcServerId = hex.EncodeToString(md.Sum(nil)) } +type GrpcServerHub interface { + GetSessionByResumeId(resumeId string) Session + GetSessionByPublicId(sessionId string) Session + GetSessionIdByRoomSessionId(roomSessionId string) (string, error) + + GetBackend(u *url.URL) *Backend +} + type GrpcServer struct { UnimplementedRpcBackendServer UnimplementedRpcInternalServer @@ -66,7 +74,7 @@ type GrpcServer struct { listener net.Listener serverId string // can be overwritten from tests - hub *Hub + hub GrpcServerHub } func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) { @@ -131,7 +139,7 @@ func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSession statsGrpcServerCalls.WithLabelValues("LookupSessionId").Inc() // TODO: Remove debug logging log.Printf("Lookup session id for room session id %s", request.RoomSessionId) - sid, err := s.hub.roomSessions.GetSessionId(request.RoomSessionId) + sid, err := s.hub.GetSessionIdByRoomSessionId(request.RoomSessionId) if errors.Is(err, ErrNoSuchRoomSession) { return nil, status.Error(codes.NotFound, "no such room session id") } else if err != nil { @@ -221,7 +229,7 @@ func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCou return nil, status.Error(codes.InvalidArgument, "invalid url") } - backend := s.hub.backend.GetBackend(u) + backend := s.hub.GetBackend(u) if backend == nil { return nil, status.Error(codes.NotFound, "no such backend") } @@ -233,13 +241,18 @@ func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCou func (s *GrpcServer) ProxySession(request RpcSessions_ProxySessionServer) error { statsGrpcServerCalls.WithLabelValues("ProxySession").Inc() - client, err := newRemoteGrpcClient(s.hub, request) + hub, ok := s.hub.(*Hub) + if !ok { + return status.Error(codes.Internal, "invalid hub type") + + } + client, err := newRemoteGrpcClient(hub, request) if err != nil { return err } - sid := s.hub.registerClient(client) - defer s.hub.unregisterClient(sid) + sid := hub.registerClient(client) + defer hub.unregisterClient(sid) return client.run() } diff --git a/hub.go b/hub.go index 5684f3c..77a86c7 100644 --- a/hub.go +++ b/hub.go @@ -38,6 +38,7 @@ import ( "log" "net" "net/http" + "net/url" "strings" "sync" "sync/atomic" @@ -623,6 +624,10 @@ func (h *Hub) GetSessionByResumeId(resumeId string) Session { return session } +func (h *Hub) GetSessionIdByRoomSessionId(roomSessionId string) (string, error) { + return h.roomSessions.GetSessionId(roomSessionId) +} + func (h *Hub) GetDialoutSession(roomId string, backend *Backend) *ClientSession { url := backend.Url() @@ -641,6 +646,10 @@ func (h *Hub) GetDialoutSession(roomId string, backend *Backend) *ClientSession return nil } +func (h *Hub) GetBackend(u *url.URL) *Backend { + return h.backend.GetBackend(u) +} + func (h *Hub) checkExpiredSessions(now time.Time) { for session, expires := range h.expiredSessions { if now.After(expires) { diff --git a/hub_test.go b/hub_test.go index d6ab70b..2ebc64f 100644 --- a/hub_test.go +++ b/hub_test.go @@ -4029,19 +4029,19 @@ func TestClientSendOfferPermissions(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() hub.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() @@ -4170,19 +4170,19 @@ func TestClientSendOfferPermissionsAudioOnly(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() hub.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() @@ -4263,19 +4263,19 @@ func TestClientSendOfferPermissionsAudioVideo(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() hub.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() @@ -4392,19 +4392,19 @@ func TestClientSendOfferPermissionsAudioVideoMedia(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() hub.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() @@ -4539,10 +4539,13 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t) } + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() @@ -4550,9 +4553,6 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { hub1.SetMcu(mcu) hub2.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() @@ -4965,10 +4965,13 @@ func TestClientSendOffer(t *testing.T) { hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t) } + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() @@ -4976,9 +4979,6 @@ func TestClientSendOffer(t *testing.T) { hub1.SetMcu(mcu) hub2.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() @@ -5073,19 +5073,19 @@ func TestClientUnshareScreen(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() hub.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() diff --git a/janus_client.go b/janus_client.go index 0865f45..b7b33a5 100644 --- a/janus_client.go +++ b/janus_client.go @@ -258,8 +258,8 @@ type JanusGateway struct { // return gateway, nil // } -func NewJanusGateway(wsURL string, listener GatewayListener) (*JanusGateway, error) { - conn, _, err := janusDialer.Dial(wsURL, nil) +func NewJanusGateway(ctx context.Context, wsURL string, listener GatewayListener) (*JanusGateway, error) { + conn, _, err := janusDialer.DialContext(ctx, wsURL, nil) if err != nil { return nil, err } diff --git a/mcu_common.go b/mcu_common.go index 3bea933..8ac820c 100644 --- a/mcu_common.go +++ b/mcu_common.go @@ -66,7 +66,7 @@ type McuInitiator interface { } type Mcu interface { - Start() error + Start(ctx context.Context) error Stop() Reload(config *goconf.ConfigFile) @@ -76,7 +76,48 @@ type Mcu interface { GetStats() interface{} NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) - NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) + NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) +} + +// PublisherStream contains the available properties when creating a +// remote publisher in Janus. +type PublisherStream struct { + Mid string `json:"mid"` + Mindex int `json:"mindex"` + Type string `json:"type"` + + Description string `json:"description,omitempty"` + Disabled bool `json:"disabled,omitempty"` + + // For types "audio" and "video" + Codec string `json:"codec,omitempty"` + + // For type "audio" + Stereo bool `json:"stereo,omitempty"` + Fec bool `json:"fec,omitempty"` + Dtx bool `json:"dtx,omitempty"` + + // For type "video" + Simulcast bool `json:"simulcast,omitempty"` + Svc bool `json:"svc,omitempty"` + + ProfileH264 string `json:"h264_profile,omitempty"` + ProfileVP9 string `json:"vp9_profile,omitempty"` + + ExtIdVideoOrientation int `json:"videoorient_ext_id,omitempty"` + ExtIdPlayoutDelay int `json:"playoutdelay_ext_id,omitempty"` +} + +type RemotePublisherController interface { + PublisherId() string + + StartPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error + GetStreams(ctx context.Context) ([]PublisherStream, error) +} + +type RemoteMcu interface { + NewRemotePublisher(ctx context.Context, listener McuListener, controller RemotePublisherController, streamType StreamType) (McuRemotePublisher, error) + NewRemoteSubscriber(ctx context.Context, listener McuListener, publisher McuRemotePublisher) (McuRemoteSubscriber, error) } type StreamType string @@ -116,6 +157,10 @@ type McuPublisher interface { HasMedia(MediaType) bool SetMedia(MediaType) + + GetStreams(ctx context.Context) ([]PublisherStream, error) + PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error + UnpublishRemote(ctx context.Context, remoteId string) error } type McuSubscriber interface { @@ -123,3 +168,18 @@ type McuSubscriber interface { Publisher() string } + +type McuRemotePublisherProperties interface { + Port() int + RtcpPort() int +} + +type McuRemotePublisher interface { + McuClient + + McuRemotePublisherProperties +} + +type McuRemoteSubscriber interface { + McuSubscriber +} diff --git a/mcu_common_test.go b/mcu_common_test.go index 0609ef3..6304638 100644 --- a/mcu_common_test.go +++ b/mcu_common_test.go @@ -28,3 +28,43 @@ import ( func TestCommonMcuStats(t *testing.T) { collectAndLint(t, commonMcuStats...) } + +type MockMcuListener struct { + publicId string +} + +func (m *MockMcuListener) PublicId() string { + return m.publicId +} + +func (m *MockMcuListener) OnUpdateOffer(client McuClient, offer map[string]interface{}) { + +} + +func (m *MockMcuListener) OnIceCandidate(client McuClient, candidate interface{}) { + +} + +func (m *MockMcuListener) OnIceCompleted(client McuClient) { + +} + +func (m *MockMcuListener) SubscriberSidUpdated(subscriber McuSubscriber) { + +} + +func (m *MockMcuListener) PublisherClosed(publisher McuPublisher) { + +} + +func (m *MockMcuListener) SubscriberClosed(subscriber McuSubscriber) { + +} + +type MockMcuInitiator struct { + country string +} + +func (m *MockMcuInitiator) Country() string { + return m.country +} diff --git a/mcu_janus.go b/mcu_janus.go index 6048a7c..0f70328 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -23,11 +23,10 @@ package signaling import ( "context" - "database/sql" "encoding/json" + "errors" "fmt" "log" - "reflect" "strconv" "sync" "sync/atomic" @@ -53,6 +52,8 @@ const ( ) var ( + ErrRemoteStreamsNotSupported = errors.New("Need Janus 1.1.0 for remote streams") + streamTypeUserIds = map[StreamType]uint64{ StreamTypeVideo: videoPublisherUserId, StreamTypeScreen: screenPublisherUserId, @@ -143,6 +144,7 @@ type mcuJanus struct { gw *JanusGateway session *JanusSession handle *JanusHandle + version int closeChan chan struct{} @@ -154,6 +156,7 @@ type mcuJanus struct { publishers map[string]*mcuJanusPublisher publisherCreated Notifier publisherConnected Notifier + remotePublishers map[string]*mcuJanusRemotePublisher reconnectTimer *time.Timer reconnectInterval time.Duration @@ -166,7 +169,7 @@ type mcuJanus struct { func emptyOnConnected() {} func emptyOnDisconnected() {} -func NewMcuJanus(url string, config *goconf.ConfigFile) (Mcu, error) { +func NewMcuJanus(ctx context.Context, url string, config *goconf.ConfigFile) (Mcu, error) { maxStreamBitrate, _ := config.GetInt("mcu", "maxstreambitrate") if maxStreamBitrate <= 0 { maxStreamBitrate = defaultMaxStreamBitrate @@ -189,16 +192,19 @@ func NewMcuJanus(url string, config *goconf.ConfigFile) (Mcu, error) { closeChan: make(chan struct{}, 1), clients: make(map[clientInterface]bool), - publishers: make(map[string]*mcuJanusPublisher), + publishers: make(map[string]*mcuJanusPublisher), + remotePublishers: make(map[string]*mcuJanusRemotePublisher), reconnectInterval: initialReconnectInterval, } mcu.onConnected.Store(emptyOnConnected) mcu.onDisconnected.Store(emptyOnDisconnected) - mcu.reconnectTimer = time.AfterFunc(mcu.reconnectInterval, mcu.doReconnect) + mcu.reconnectTimer = time.AfterFunc(mcu.reconnectInterval, func() { + mcu.doReconnect(context.Background()) + }) mcu.reconnectTimer.Stop() - if err := mcu.reconnect(); err != nil { + if err := mcu.reconnect(ctx); err != nil { return nil, err } return mcu, nil @@ -226,9 +232,9 @@ func (m *mcuJanus) disconnect() { } } -func (m *mcuJanus) reconnect() error { +func (m *mcuJanus) reconnect(ctx context.Context) error { m.disconnect() - gw, err := NewJanusGateway(m.url, m) + gw, err := NewJanusGateway(ctx, m.url, m) if err != nil { return err } @@ -238,12 +244,12 @@ func (m *mcuJanus) reconnect() error { return nil } -func (m *mcuJanus) doReconnect() { - if err := m.reconnect(); err != nil { +func (m *mcuJanus) doReconnect(ctx context.Context) { + if err := m.reconnect(ctx); err != nil { m.scheduleReconnect(err) return } - if err := m.Start(); err != nil { + if err := m.Start(ctx); err != nil { m.scheduleReconnect(err) return } @@ -288,8 +294,11 @@ func (m *mcuJanus) isMultistream() bool { return m.version >= 1000 } -func (m *mcuJanus) Start() error { - ctx := context.TODO() +func (m *mcuJanus) hasRemotePublisher() bool { + return m.version >= 1100 +} + +func (m *mcuJanus) Start(ctx context.Context) error { info, err := m.gw.Info(ctx) if err != nil { return err @@ -356,7 +365,7 @@ loop: for { select { case <-ticker.C: - m.sendKeepalive() + m.sendKeepalive(context.Background()) case <-m.closeChan: break loop } @@ -422,8 +431,7 @@ func (m *mcuJanus) GetStats() interface{} { return result } -func (m *mcuJanus) sendKeepalive() { - ctx := context.TODO() +func (m *mcuJanus) sendKeepalive(ctx context.Context) { if _, err := m.session.KeepAlive(ctx); err != nil { log.Println("Could not send keepalive request", err) if e, ok := err.(*janus.ErrorMsg); ok { @@ -435,272 +443,6 @@ func (m *mcuJanus) sendKeepalive() { } } -type mcuJanusClient struct { - mcu *mcuJanus - listener McuListener - mu sync.Mutex // nolint - - id uint64 - session uint64 - roomId uint64 - sid string - streamType StreamType - maxBitrate int - - handle *JanusHandle - handleId uint64 - closeChan chan struct{} - deferred chan func() - - handleEvent func(event *janus.EventMsg) - handleHangup func(event *janus.HangupMsg) - handleDetached func(event *janus.DetachedMsg) - handleConnected func(event *janus.WebRTCUpMsg) - handleSlowLink func(event *janus.SlowLinkMsg) - handleMedia func(event *janus.MediaMsg) -} - -func (c *mcuJanusClient) Id() string { - return strconv.FormatUint(c.id, 10) -} - -func (c *mcuJanusClient) Sid() string { - return c.sid -} - -func (c *mcuJanusClient) StreamType() StreamType { - return c.streamType -} - -func (c *mcuJanusClient) MaxBitrate() int { - return c.maxBitrate -} - -func (c *mcuJanusClient) Close(ctx context.Context) { -} - -func (c *mcuJanusClient) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { -} - -func (c *mcuJanusClient) closeClient(ctx context.Context) bool { - if handle := c.handle; handle != nil { - c.handle = nil - close(c.closeChan) - if _, err := handle.Detach(ctx); err != nil { - if e, ok := err.(*janus.ErrorMsg); !ok || e.Err.Code != JANUS_ERROR_HANDLE_NOT_FOUND { - log.Println("Could not detach client", handle.Id, err) - } - } - return true - } - - return false -} - -func (c *mcuJanusClient) run(handle *JanusHandle, closeChan <-chan struct{}) { -loop: - for { - select { - case msg := <-handle.Events: - switch t := msg.(type) { - case *janus.EventMsg: - c.handleEvent(t) - case *janus.HangupMsg: - c.handleHangup(t) - case *janus.DetachedMsg: - c.handleDetached(t) - case *janus.MediaMsg: - c.handleMedia(t) - case *janus.WebRTCUpMsg: - c.handleConnected(t) - case *janus.SlowLinkMsg: - c.handleSlowLink(t) - case *TrickleMsg: - c.handleTrickle(t) - default: - log.Println("Received unsupported event type", msg, reflect.TypeOf(msg)) - } - case f := <-c.deferred: - f() - case <-closeChan: - break loop - } - } -} - -func (c *mcuJanusClient) sendOffer(ctx context.Context, offer map[string]interface{}, callback func(error, map[string]interface{})) { - handle := c.handle - if handle == nil { - callback(ErrNotConnected, nil) - return - } - - configure_msg := map[string]interface{}{ - "request": "configure", - "audio": true, - "video": true, - "data": true, - } - answer_msg, err := handle.Message(ctx, configure_msg, offer) - if err != nil { - callback(err, nil) - return - } - - callback(nil, answer_msg.Jsep) -} - -func (c *mcuJanusClient) sendAnswer(ctx context.Context, answer map[string]interface{}, callback func(error, map[string]interface{})) { - handle := c.handle - if handle == nil { - callback(ErrNotConnected, nil) - return - } - - start_msg := map[string]interface{}{ - "request": "start", - "room": c.roomId, - } - start_response, err := handle.Message(ctx, start_msg, answer) - if err != nil { - callback(err, nil) - return - } - log.Println("Started listener", start_response) - callback(nil, nil) -} - -func (c *mcuJanusClient) sendCandidate(ctx context.Context, candidate interface{}, callback func(error, map[string]interface{})) { - handle := c.handle - if handle == nil { - callback(ErrNotConnected, nil) - return - } - - if _, err := handle.Trickle(ctx, candidate); err != nil { - callback(err, nil) - return - } - callback(nil, nil) -} - -func (c *mcuJanusClient) handleTrickle(event *TrickleMsg) { - if event.Candidate.Completed { - c.listener.OnIceCompleted(c) - } else { - c.listener.OnIceCandidate(c, event.Candidate) - } -} - -func (c *mcuJanusClient) selectStream(ctx context.Context, stream *streamSelection, callback func(error, map[string]interface{})) { - handle := c.handle - if handle == nil { - callback(ErrNotConnected, nil) - return - } - - if stream == nil || !stream.HasValues() { - callback(nil, nil) - return - } - - configure_msg := map[string]interface{}{ - "request": "configure", - } - if stream != nil { - stream.AddToMessage(configure_msg) - } - _, err := handle.Message(ctx, configure_msg, nil) - if err != nil { - callback(err, nil) - return - } - - callback(nil, nil) -} - -type publisherStatsCounter struct { - mu sync.Mutex - - streamTypes map[StreamType]bool - subscribers map[string]bool -} - -func (c *publisherStatsCounter) Reset() { - c.mu.Lock() - defer c.mu.Unlock() - - count := len(c.subscribers) - for streamType := range c.streamTypes { - statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec() - statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(count)) - } - c.streamTypes = nil - c.subscribers = nil -} - -func (c *publisherStatsCounter) EnableStream(streamType StreamType, enable bool) { - c.mu.Lock() - defer c.mu.Unlock() - - if enable == c.streamTypes[streamType] { - return - } - - if enable { - if c.streamTypes == nil { - c.streamTypes = make(map[StreamType]bool) - } - c.streamTypes[streamType] = true - statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Inc() - statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Add(float64(len(c.subscribers))) - } else { - delete(c.streamTypes, streamType) - statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec() - statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(len(c.subscribers))) - } -} - -func (c *publisherStatsCounter) AddSubscriber(id string) { - c.mu.Lock() - defer c.mu.Unlock() - - if c.subscribers[id] { - return - } - - if c.subscribers == nil { - c.subscribers = make(map[string]bool) - } - c.subscribers[id] = true - for streamType := range c.streamTypes { - statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Inc() - } -} - -func (c *publisherStatsCounter) RemoveSubscriber(id string) { - c.mu.Lock() - defer c.mu.Unlock() - - if !c.subscribers[id] { - return - } - - delete(c.subscribers, id) - for streamType := range c.streamTypes { - statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Dec() - } -} - -type mcuJanusPublisher struct { - mcuJanusClient - - id string - bitrate int - mediaTypes MediaType - stats publisherStatsCounter -} - func (m *mcuJanus) SubscriberConnected(id string, publisher string, streamType StreamType) { m.mu.Lock() defer m.mu.Unlock() @@ -719,17 +461,7 @@ func (m *mcuJanus) SubscriberDisconnected(id string, publisher string, streamTyp } } -func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, streamType StreamType, bitrate int) (*JanusHandle, uint64, uint64, int, error) { - session := m.session - if session == nil { - return nil, 0, 0, 0, ErrNotConnected - } - handle, err := session.Attach(ctx, pluginVideoRoom) - if err != nil { - return nil, 0, 0, 0, err - } - - log.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id) +func (m *mcuJanus) createPublisherRoom(ctx context.Context, handle *JanusHandle, id string, streamType StreamType, bitrate int) (uint64, int, error) { create_msg := map[string]interface{}{ "request": "create", "description": getStreamId(id, streamType), @@ -756,7 +488,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st if _, err2 := handle.Detach(ctx); err2 != nil { log.Printf("Error detaching handle %d: %s", handle.Id, err2) } - return nil, 0, 0, 0, err + return 0, 0, err } roomId := getPluginIntValue(create_response.PluginData, pluginVideoRoom, "room") @@ -764,10 +496,32 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st if _, err := handle.Detach(ctx); err != nil { log.Printf("Error detaching handle %d: %s", handle.Id, err) } - return nil, 0, 0, 0, fmt.Errorf("No room id received: %+v", create_response) + return 0, 0, fmt.Errorf("No room id received: %+v", create_response) } log.Println("Created room", roomId, create_response.PluginData) + return roomId, bitrate, nil +} + +func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, streamType StreamType, bitrate int) (*JanusHandle, uint64, uint64, int, error) { + session := m.session + if session == nil { + return nil, 0, 0, 0, ErrNotConnected + } + handle, err := session.Attach(ctx, pluginVideoRoom) + if err != nil { + return nil, 0, 0, 0, err + } + + log.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id) + + roomId, bitrate, err := m.createPublisherRoom(ctx, handle, id, streamType, bitrate) + if err != nil { + if _, err2 := handle.Detach(ctx); err2 != nil { + log.Printf("Error detaching handle %d: %s", handle.Id, err2) + } + return nil, 0, 0, 0, err + } msg := map[string]interface{}{ "request": "join", @@ -814,6 +568,7 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id st closeChan: make(chan struct{}, 1), deferred: make(chan func(), 64), }, + sdpReady: NewCloser(), id: id, bitrate: bitrate, mediaTypes: mediaTypes, @@ -837,150 +592,6 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id st return client, nil } -func (p *mcuJanusPublisher) handleEvent(event *janus.EventMsg) { - if videoroom := getPluginStringValue(event.Plugindata, pluginVideoRoom, "videoroom"); videoroom != "" { - ctx := context.TODO() - switch videoroom { - case "destroyed": - log.Printf("Publisher %d: associated room has been destroyed, closing", p.handleId) - go p.Close(ctx) - case "slow_link": - // Ignore, processed through "handleSlowLink" in the general events. - default: - log.Printf("Unsupported videoroom publisher event in %d: %+v", p.handleId, event) - } - } else { - log.Printf("Unsupported publisher event in %d: %+v", p.handleId, event) - } -} - -func (p *mcuJanusPublisher) handleHangup(event *janus.HangupMsg) { - log.Printf("Publisher %d received hangup (%s), closing", p.handleId, event.Reason) - go p.Close(context.Background()) -} - -func (p *mcuJanusPublisher) handleDetached(event *janus.DetachedMsg) { - log.Printf("Publisher %d received detached, closing", p.handleId) - go p.Close(context.Background()) -} - -func (p *mcuJanusPublisher) handleConnected(event *janus.WebRTCUpMsg) { - log.Printf("Publisher %d received connected", p.handleId) - p.mcu.publisherConnected.Notify(getStreamId(p.id, p.streamType)) -} - -func (p *mcuJanusPublisher) handleSlowLink(event *janus.SlowLinkMsg) { - if event.Uplink { - log.Printf("Publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost) - } else { - log.Printf("Publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost) - } -} - -func (p *mcuJanusPublisher) handleMedia(event *janus.MediaMsg) { - mediaType := StreamType(event.Type) - if mediaType == StreamTypeVideo && p.streamType == StreamTypeScreen { - // We want to differentiate between audio, video and screensharing - mediaType = p.streamType - } - - p.stats.EnableStream(mediaType, event.Receiving) -} - -func (p *mcuJanusPublisher) HasMedia(mt MediaType) bool { - return (p.mediaTypes & mt) == mt -} - -func (p *mcuJanusPublisher) SetMedia(mt MediaType) { - p.mediaTypes = mt -} - -func (p *mcuJanusPublisher) NotifyReconnected() { - ctx := context.TODO() - handle, session, roomId, _, err := p.mcu.getOrCreatePublisherHandle(ctx, p.id, p.streamType, p.bitrate) - if err != nil { - log.Printf("Could not reconnect publisher %s: %s", p.id, err) - // TODO(jojo): Retry - return - } - - p.handle = handle - p.handleId = handle.Id - p.session = session - p.roomId = roomId - - log.Printf("Publisher %s reconnected on handle %d", p.id, p.handleId) -} - -func (p *mcuJanusPublisher) Close(ctx context.Context) { - notify := false - p.mu.Lock() - if handle := p.handle; handle != nil && p.roomId != 0 { - destroy_msg := map[string]interface{}{ - "request": "destroy", - "room": p.roomId, - } - if _, err := handle.Request(ctx, destroy_msg); err != nil { - log.Printf("Error destroying room %d: %s", p.roomId, err) - } else { - log.Printf("Room %d destroyed", p.roomId) - } - p.mcu.mu.Lock() - delete(p.mcu.publishers, getStreamId(p.id, p.streamType)) - p.mcu.mu.Unlock() - p.roomId = 0 - notify = true - } - p.closeClient(ctx) - p.mu.Unlock() - - p.stats.Reset() - - if notify { - statsPublishersCurrent.WithLabelValues(string(p.streamType)).Dec() - p.mcu.unregisterClient(p) - p.listener.PublisherClosed(p) - } - p.mcuJanusClient.Close(ctx) -} - -func (p *mcuJanusPublisher) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { - statsMcuMessagesTotal.WithLabelValues(data.Type).Inc() - jsep_msg := data.Payload - switch data.Type { - case "offer": - p.deferred <- func() { - msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) - defer cancel() - - // TODO Tear down previous publisher and get a new one if sid does - // not match? - p.sendOffer(msgctx, jsep_msg, callback) - } - case "candidate": - p.deferred <- func() { - msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) - defer cancel() - - if data.Sid == "" || data.Sid == p.Sid() { - p.sendCandidate(msgctx, jsep_msg["candidate"], callback) - } else { - go callback(fmt.Errorf("Candidate message sid (%s) does not match publisher sid (%s)", data.Sid, p.Sid()), nil) - } - } - case "endOfCandidates": - // Ignore - default: - go callback(fmt.Errorf("Unsupported message type: %s", data.Type), nil) - } -} - -type mcuJanusSubscriber struct { - mcuJanusClient - - publisher string -} - func (m *mcuJanus) getPublisher(ctx context.Context, publisher string, streamType StreamType) (*mcuJanusPublisher, error) { // Do the direct check immediately as this should be the normal case. key := getStreamId(publisher, streamType) @@ -1029,7 +640,7 @@ func (m *mcuJanus) getOrCreateSubscriberHandle(ctx context.Context, publisher st return handle, pub, nil } -func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) { +func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) { if _, found := streamTypeUserIds[streamType]; !found { return nil, fmt.Errorf("Unsupported stream type %s", streamType) } @@ -1070,369 +681,167 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ return client, nil } -func (p *mcuJanusSubscriber) Publisher() string { - return p.publisher -} - -func (p *mcuJanusSubscriber) handleEvent(event *janus.EventMsg) { - if videoroom := getPluginStringValue(event.Plugindata, pluginVideoRoom, "videoroom"); videoroom != "" { - ctx := context.TODO() - switch videoroom { - case "destroyed": - log.Printf("Subscriber %d: associated room has been destroyed, closing", p.handleId) - go p.Close(ctx) - case "event": - // Handle renegotiations, but ignore other events like selected - // substream / temporal layer. - if getPluginStringValue(event.Plugindata, pluginVideoRoom, "configured") == "ok" && - event.Jsep != nil && event.Jsep["type"] == "offer" && event.Jsep["sdp"] != nil { - p.listener.OnUpdateOffer(p, event.Jsep) - } - case "slow_link": - // Ignore, processed through "handleSlowLink" in the general events. - default: - log.Printf("Unsupported videoroom event %s for subscriber %d: %+v", videoroom, p.handleId, event) - } - } else { - log.Printf("Unsupported event for subscriber %d: %+v", p.handleId, event) +func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller RemotePublisherController, streamType StreamType, bitrate int) (*mcuJanusRemotePublisher, error) { + m.mu.Lock() + defer m.mu.Unlock() + pub, found := m.remotePublishers[getStreamId(controller.PublisherId(), streamType)] + if found { + return pub, nil } -} -func (p *mcuJanusSubscriber) handleHangup(event *janus.HangupMsg) { - log.Printf("Subscriber %d received hangup (%s), closing", p.handleId, event.Reason) - go p.Close(context.Background()) -} - -func (p *mcuJanusSubscriber) handleDetached(event *janus.DetachedMsg) { - log.Printf("Subscriber %d received detached, closing", p.handleId) - go p.Close(context.Background()) -} - -func (p *mcuJanusSubscriber) handleConnected(event *janus.WebRTCUpMsg) { - log.Printf("Subscriber %d received connected", p.handleId) - p.mcu.SubscriberConnected(p.Id(), p.publisher, p.streamType) -} - -func (p *mcuJanusSubscriber) handleSlowLink(event *janus.SlowLinkMsg) { - if event.Uplink { - log.Printf("Subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost) - } else { - log.Printf("Subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost) - } -} - -func (p *mcuJanusSubscriber) handleMedia(event *janus.MediaMsg) { - // Only triggered for publishers -} - -func (p *mcuJanusSubscriber) NotifyReconnected() { - ctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) - defer cancel() - handle, pub, err := p.mcu.getOrCreateSubscriberHandle(ctx, p.publisher, p.streamType) + streams, err := controller.GetStreams(ctx) if err != nil { - // TODO(jojo): Retry? - log.Printf("Could not reconnect subscriber for publisher %s: %s", p.publisher, err) - p.Close(context.Background()) - return + return nil, err } - p.handle = handle - p.handleId = handle.Id - p.roomId = pub.roomId - p.sid = strconv.FormatUint(handle.Id, 10) - p.listener.SubscriberSidUpdated(p) - log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId) -} - -func (p *mcuJanusSubscriber) Close(ctx context.Context) { - p.mu.Lock() - closed := p.closeClient(ctx) - p.mu.Unlock() - - if closed { - p.mcu.SubscriberDisconnected(p.Id(), p.publisher, p.streamType) - statsSubscribersCurrent.WithLabelValues(string(p.streamType)).Dec() - } - p.mcu.unregisterClient(p) - p.listener.SubscriberClosed(p) - p.mcuJanusClient.Close(ctx) -} - -func (p *mcuJanusSubscriber) joinRoom(ctx context.Context, stream *streamSelection, callback func(error, map[string]interface{})) { - handle := p.handle - if handle == nil { - callback(ErrNotConnected, nil) - return + if len(streams) == 0 { + return nil, errors.New("remote publisher has no streams") } - waiter := p.mcu.publisherConnected.NewWaiter(getStreamId(p.publisher, p.streamType)) - defer p.mcu.publisherConnected.Release(waiter) - - loggedNotPublishingYet := false -retry: - join_msg := map[string]interface{}{ - "request": "join", - "ptype": "subscriber", - "room": p.roomId, + session := m.session + if session == nil { + return nil, ErrNotConnected } - if p.mcu.isMultistream() { - join_msg["streams"] = []map[string]interface{}{ - { - "feed": streamTypeUserIds[p.streamType], + + handle, err := session.Attach(ctx, pluginVideoRoom) + if err != nil { + return nil, err + } + + roomId, bitrate, err := m.createPublisherRoom(ctx, handle, controller.PublisherId(), streamType, bitrate) + if err != nil { + if _, err2 := handle.Detach(ctx); err2 != nil { + log.Printf("Error detaching handle %d: %s", handle.Id, err2) + } + return nil, err + } + + response, err := handle.Request(ctx, map[string]interface{}{ + "request": "add_remote_publisher", + "room": roomId, + "id": streamTypeUserIds[streamType], + "streams": streams, + }) + if err != nil { + if _, err2 := handle.Detach(ctx); err2 != nil { + log.Printf("Error detaching handle %d: %s", handle.Id, err2) + } + return nil, err + } + + id := getPluginIntValue(response.PluginData, pluginVideoRoom, "id") + port := getPluginIntValue(response.PluginData, pluginVideoRoom, "port") + rtcp_port := getPluginIntValue(response.PluginData, pluginVideoRoom, "rtcp_port") + + pub = &mcuJanusRemotePublisher{ + mcuJanusPublisher: mcuJanusPublisher{ + mcuJanusClient: mcuJanusClient{ + mcu: m, + + id: id, + session: response.Session, + roomId: roomId, + sid: strconv.FormatUint(handle.Id, 10), + streamType: streamType, + maxBitrate: bitrate, + + handle: handle, + handleId: handle.Id, + closeChan: make(chan struct{}, 1), + deferred: make(chan func(), 64), }, - } - } else { - join_msg["feed"] = streamTypeUserIds[p.streamType] + + sdpReady: NewCloser(), + id: controller.PublisherId(), + }, + + port: int(port), + rtcpPort: int(rtcp_port), } - if stream != nil { - stream.AddToMessage(join_msg) + pub.mcuJanusClient.handleEvent = pub.handleEvent + pub.mcuJanusClient.handleHangup = pub.handleHangup + pub.mcuJanusClient.handleDetached = pub.handleDetached + pub.mcuJanusClient.handleConnected = pub.handleConnected + pub.mcuJanusClient.handleSlowLink = pub.handleSlowLink + pub.mcuJanusClient.handleMedia = pub.handleMedia + + if err := controller.StartPublishing(ctx, pub); err != nil { + go pub.Close(context.Background()) + return nil, err } - join_response, err := handle.Message(ctx, join_msg, nil) + + m.remotePublishers[getStreamId(controller.PublisherId(), streamType)] = pub + + return pub, nil +} + +func (m *mcuJanus) NewRemotePublisher(ctx context.Context, listener McuListener, controller RemotePublisherController, streamType StreamType) (McuRemotePublisher, error) { + if _, found := streamTypeUserIds[streamType]; !found { + return nil, fmt.Errorf("Unsupported stream type %s", streamType) + } + + if !m.hasRemotePublisher() { + return nil, ErrRemoteStreamsNotSupported + } + + pub, err := m.getOrCreateRemotePublisher(ctx, controller, streamType, 0) if err != nil { - callback(err, nil) - return + return nil, err } - if error_code := getPluginIntValue(join_response.Plugindata, pluginVideoRoom, "error_code"); error_code > 0 { - switch error_code { - case JANUS_VIDEOROOM_ERROR_ALREADY_JOINED: - // The subscriber is already connected to the room. This can happen - // if a client leaves a call but keeps the subscriber objects active. - // On joining the call again, the subscriber tries to join on the - // MCU which will fail because he is still connected. - // To get a new Offer SDP, we have to tear down the session on the - // MCU and join again. - p.mu.Lock() - p.closeClient(ctx) - p.mu.Unlock() - - var pub *mcuJanusPublisher - handle, pub, err = p.mcu.getOrCreateSubscriberHandle(ctx, p.publisher, p.streamType) - if err != nil { - // Reconnection didn't work, need to unregister/remove subscriber - // so a new object will be created if the request is retried. - p.mcu.unregisterClient(p) - p.listener.SubscriberClosed(p) - callback(fmt.Errorf("Already connected as subscriber for %s, error during re-joining: %s", p.streamType, err), nil) - return - } - - p.handle = handle - p.handleId = handle.Id - p.roomId = pub.roomId - p.sid = strconv.FormatUint(handle.Id, 10) - p.listener.SubscriberSidUpdated(p) - p.closeChan = make(chan struct{}, 1) - go p.run(p.handle, p.closeChan) - log.Printf("Already connected subscriber %d for %s, leaving and re-joining on handle %d", p.id, p.streamType, p.handleId) - goto retry - case JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM: - fallthrough - case JANUS_VIDEOROOM_ERROR_NO_SUCH_FEED: - switch error_code { - case JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM: - log.Printf("Publisher %s not created yet for %s, wait and retry to join room %d as subscriber", p.publisher, p.streamType, p.roomId) - case JANUS_VIDEOROOM_ERROR_NO_SUCH_FEED: - log.Printf("Publisher %s not sending yet for %s, wait and retry to join room %d as subscriber", p.publisher, p.streamType, p.roomId) - } - - if !loggedNotPublishingYet { - loggedNotPublishingYet = true - statsWaitingForPublisherTotal.WithLabelValues(string(p.streamType)).Inc() - } - - if err := waiter.Wait(ctx); err != nil { - callback(err, nil) - return - } - log.Printf("Retry subscribing %s from %s", p.streamType, p.publisher) - goto retry - default: - // TODO(jojo): Should we handle other errors, too? - callback(fmt.Errorf("Error joining room as subscriber: %+v", join_response), nil) - return - } - } - //log.Println("Joined as listener", join_response) - - p.session = join_response.Session - callback(nil, join_response.Jsep) + pub.addRef() + return pub, nil } -func (p *mcuJanusSubscriber) update(ctx context.Context, stream *streamSelection, callback func(error, map[string]interface{})) { - handle := p.handle - if handle == nil { - callback(ErrNotConnected, nil) - return +func (m *mcuJanus) NewRemoteSubscriber(ctx context.Context, listener McuListener, publisher McuRemotePublisher) (McuRemoteSubscriber, error) { + pub, ok := publisher.(*mcuJanusRemotePublisher) + if !ok { + return nil, errors.New("unsupported remote publisher") } - configure_msg := map[string]interface{}{ - "request": "configure", - "update": true, + session := m.session + if session == nil { + return nil, ErrNotConnected } - if stream != nil { - stream.AddToMessage(configure_msg) - } - configure_response, err := handle.Message(ctx, configure_msg, nil) + + handle, err := session.Attach(ctx, pluginVideoRoom) if err != nil { - callback(err, nil) - return + return nil, err } - callback(nil, configure_response.Jsep) -} - -type streamSelection struct { - substream sql.NullInt16 - temporal sql.NullInt16 - audio sql.NullBool - video sql.NullBool -} - -func (s *streamSelection) HasValues() bool { - return s.substream.Valid || s.temporal.Valid || s.audio.Valid || s.video.Valid -} - -func (s *streamSelection) AddToMessage(message map[string]interface{}) { - if s.substream.Valid { - message["substream"] = s.substream.Int16 - } - if s.temporal.Valid { - message["temporal"] = s.temporal.Int16 - } - if s.audio.Valid { - message["audio"] = s.audio.Bool - } - if s.video.Valid { - message["video"] = s.video.Bool - } -} - -func parseStreamSelection(payload map[string]interface{}) (*streamSelection, error) { - var stream streamSelection - if value, found := payload["substream"]; found { - switch value := value.(type) { - case int: - stream.substream.Valid = true - stream.substream.Int16 = int16(value) - case float32: - stream.substream.Valid = true - stream.substream.Int16 = int16(value) - case float64: - stream.substream.Valid = true - stream.substream.Int16 = int16(value) - default: - return nil, fmt.Errorf("Unsupported substream value: %v", value) - } - } - - if value, found := payload["temporal"]; found { - switch value := value.(type) { - case int: - stream.temporal.Valid = true - stream.temporal.Int16 = int16(value) - case float32: - stream.temporal.Valid = true - stream.temporal.Int16 = int16(value) - case float64: - stream.temporal.Valid = true - stream.temporal.Int16 = int16(value) - default: - return nil, fmt.Errorf("Unsupported temporal value: %v", value) - } - } - - if value, found := payload["audio"]; found { - switch value := value.(type) { - case bool: - stream.audio.Valid = true - stream.audio.Bool = value - default: - return nil, fmt.Errorf("Unsupported audio value: %v", value) - } - } - - if value, found := payload["video"]; found { - switch value := value.(type) { - case bool: - stream.video.Valid = true - stream.video.Bool = value - default: - return nil, fmt.Errorf("Unsupported video value: %v", value) - } - } - - return &stream, nil -} - -func (p *mcuJanusSubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { - statsMcuMessagesTotal.WithLabelValues(data.Type).Inc() - jsep_msg := data.Payload - switch data.Type { - case "requestoffer": - fallthrough - case "sendoffer": - p.deferred <- func() { - msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) - defer cancel() - - stream, err := parseStreamSelection(jsep_msg) - if err != nil { - go callback(err, nil) - return - } - - if data.Sid == "" || data.Sid != p.Sid() { - p.joinRoom(msgctx, stream, callback) - } else { - p.update(msgctx, stream, callback) - } - } - case "answer": - p.deferred <- func() { - msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) - defer cancel() - - if data.Sid == "" || data.Sid == p.Sid() { - p.sendAnswer(msgctx, jsep_msg, callback) - } else { - go callback(fmt.Errorf("Answer message sid (%s) does not match subscriber sid (%s)", data.Sid, p.Sid()), nil) - } - } - case "candidate": - p.deferred <- func() { - msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) - defer cancel() - - if data.Sid == "" || data.Sid == p.Sid() { - p.sendCandidate(msgctx, jsep_msg["candidate"], callback) - } else { - go callback(fmt.Errorf("Candidate message sid (%s) does not match subscriber sid (%s)", data.Sid, p.Sid()), nil) - } - } - case "endOfCandidates": - // Ignore - case "selectStream": - stream, err := parseStreamSelection(jsep_msg) - if err != nil { - go callback(err, nil) - return - } - - if stream == nil || !stream.HasValues() { - // Nothing to do - go callback(nil, nil) - return - } - - p.deferred <- func() { - msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) - defer cancel() - - p.selectStream(msgctx, stream, callback) - } - default: - // Return error asynchronously - go callback(fmt.Errorf("Unsupported message type: %s", data.Type), nil) - } + log.Printf("Attached subscriber to room %d of publisher %s in plugin %s in session %d as %d", pub.roomId, pub.id, pluginVideoRoom, session.Id, handle.Id) + + client := &mcuJanusRemoteSubscriber{ + mcuJanusSubscriber: mcuJanusSubscriber{ + mcuJanusClient: mcuJanusClient{ + mcu: m, + listener: listener, + + id: m.clientId.Add(1), + roomId: pub.roomId, + sid: strconv.FormatUint(handle.Id, 10), + streamType: publisher.StreamType(), + maxBitrate: pub.MaxBitrate(), + + handle: handle, + handleId: handle.Id, + closeChan: make(chan struct{}, 1), + deferred: make(chan func(), 64), + }, + publisher: pub.id, + }, + } + client.remote.Store(pub) + pub.addRef() + client.mcuJanusClient.handleEvent = client.handleEvent + client.mcuJanusClient.handleHangup = client.handleHangup + client.mcuJanusClient.handleDetached = client.handleDetached + client.mcuJanusClient.handleConnected = client.handleConnected + client.mcuJanusClient.handleSlowLink = client.handleSlowLink + client.mcuJanusClient.handleMedia = client.handleMedia + m.registerClient(client) + go client.run(handle, client.closeChan) + statsSubscribersCurrent.WithLabelValues(string(publisher.StreamType())).Inc() + statsSubscribersTotal.WithLabelValues(string(publisher.StreamType())).Inc() + return client, nil } diff --git a/mcu_janus_client.go b/mcu_janus_client.go new file mode 100644 index 0000000..f1d254b --- /dev/null +++ b/mcu_janus_client.go @@ -0,0 +1,216 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2017 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "log" + "reflect" + "strconv" + "sync" + + "github.com/notedit/janus-go" +) + +type mcuJanusClient struct { + mcu *mcuJanus + listener McuListener + mu sync.Mutex // nolint + + id uint64 + session uint64 + roomId uint64 + sid string + streamType StreamType + maxBitrate int + + handle *JanusHandle + handleId uint64 + closeChan chan struct{} + deferred chan func() + + handleEvent func(event *janus.EventMsg) + handleHangup func(event *janus.HangupMsg) + handleDetached func(event *janus.DetachedMsg) + handleConnected func(event *janus.WebRTCUpMsg) + handleSlowLink func(event *janus.SlowLinkMsg) + handleMedia func(event *janus.MediaMsg) +} + +func (c *mcuJanusClient) Id() string { + return strconv.FormatUint(c.id, 10) +} + +func (c *mcuJanusClient) Sid() string { + return c.sid +} + +func (c *mcuJanusClient) StreamType() StreamType { + return c.streamType +} + +func (c *mcuJanusClient) MaxBitrate() int { + return c.maxBitrate +} + +func (c *mcuJanusClient) Close(ctx context.Context) { +} + +func (c *mcuJanusClient) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { +} + +func (c *mcuJanusClient) closeClient(ctx context.Context) bool { + if handle := c.handle; handle != nil { + c.handle = nil + close(c.closeChan) + if _, err := handle.Detach(ctx); err != nil { + if e, ok := err.(*janus.ErrorMsg); !ok || e.Err.Code != JANUS_ERROR_HANDLE_NOT_FOUND { + log.Println("Could not detach client", handle.Id, err) + } + } + return true + } + + return false +} + +func (c *mcuJanusClient) run(handle *JanusHandle, closeChan <-chan struct{}) { +loop: + for { + select { + case msg := <-handle.Events: + switch t := msg.(type) { + case *janus.EventMsg: + c.handleEvent(t) + case *janus.HangupMsg: + c.handleHangup(t) + case *janus.DetachedMsg: + c.handleDetached(t) + case *janus.MediaMsg: + c.handleMedia(t) + case *janus.WebRTCUpMsg: + c.handleConnected(t) + case *janus.SlowLinkMsg: + c.handleSlowLink(t) + case *TrickleMsg: + c.handleTrickle(t) + default: + log.Println("Received unsupported event type", msg, reflect.TypeOf(msg)) + } + case f := <-c.deferred: + f() + case <-closeChan: + break loop + } + } +} + +func (c *mcuJanusClient) sendOffer(ctx context.Context, offer map[string]interface{}, callback func(error, map[string]interface{})) { + handle := c.handle + if handle == nil { + callback(ErrNotConnected, nil) + return + } + + configure_msg := map[string]interface{}{ + "request": "configure", + "audio": true, + "video": true, + "data": true, + } + answer_msg, err := handle.Message(ctx, configure_msg, offer) + if err != nil { + callback(err, nil) + return + } + + callback(nil, answer_msg.Jsep) +} + +func (c *mcuJanusClient) sendAnswer(ctx context.Context, answer map[string]interface{}, callback func(error, map[string]interface{})) { + handle := c.handle + if handle == nil { + callback(ErrNotConnected, nil) + return + } + + start_msg := map[string]interface{}{ + "request": "start", + "room": c.roomId, + } + start_response, err := handle.Message(ctx, start_msg, answer) + if err != nil { + callback(err, nil) + return + } + log.Println("Started listener", start_response) + callback(nil, nil) +} + +func (c *mcuJanusClient) sendCandidate(ctx context.Context, candidate interface{}, callback func(error, map[string]interface{})) { + handle := c.handle + if handle == nil { + callback(ErrNotConnected, nil) + return + } + + if _, err := handle.Trickle(ctx, candidate); err != nil { + callback(err, nil) + return + } + callback(nil, nil) +} + +func (c *mcuJanusClient) handleTrickle(event *TrickleMsg) { + if event.Candidate.Completed { + c.listener.OnIceCompleted(c) + } else { + c.listener.OnIceCandidate(c, event.Candidate) + } +} + +func (c *mcuJanusClient) selectStream(ctx context.Context, stream *streamSelection, callback func(error, map[string]interface{})) { + handle := c.handle + if handle == nil { + callback(ErrNotConnected, nil) + return + } + + if stream == nil || !stream.HasValues() { + callback(nil, nil) + return + } + + configure_msg := map[string]interface{}{ + "request": "configure", + } + if stream != nil { + stream.AddToMessage(configure_msg) + } + _, err := handle.Message(ctx, configure_msg, nil) + if err != nil { + callback(err, nil) + return + } + + callback(nil, nil) +} diff --git a/mcu_janus_publisher.go b/mcu_janus_publisher.go new file mode 100644 index 0000000..b003727 --- /dev/null +++ b/mcu_janus_publisher.go @@ -0,0 +1,457 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2017 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "errors" + "fmt" + "log" + "strconv" + "strings" + "sync/atomic" + + "github.com/notedit/janus-go" + "github.com/pion/sdp/v3" +) + +const ( + ExtensionUrlPlayoutDelay = "http://www.webrtc.org/experiments/rtp-hdrext/playout-delay" + ExtensionUrlVideoOrientation = "urn:3gpp:video-orientation" +) + +const ( + sdpHasOffer = 1 + sdpHasAnswer = 2 +) + +type mcuJanusPublisher struct { + mcuJanusClient + + id string + bitrate int + mediaTypes MediaType + stats publisherStatsCounter + sdpFlags Flags + sdpReady *Closer + offerSdp atomic.Pointer[sdp.SessionDescription] + answerSdp atomic.Pointer[sdp.SessionDescription] +} + +func (p *mcuJanusPublisher) handleEvent(event *janus.EventMsg) { + if videoroom := getPluginStringValue(event.Plugindata, pluginVideoRoom, "videoroom"); videoroom != "" { + ctx := context.TODO() + switch videoroom { + case "destroyed": + log.Printf("Publisher %d: associated room has been destroyed, closing", p.handleId) + go p.Close(ctx) + case "slow_link": + // Ignore, processed through "handleSlowLink" in the general events. + default: + log.Printf("Unsupported videoroom publisher event in %d: %+v", p.handleId, event) + } + } else { + log.Printf("Unsupported publisher event in %d: %+v", p.handleId, event) + } +} + +func (p *mcuJanusPublisher) handleHangup(event *janus.HangupMsg) { + log.Printf("Publisher %d received hangup (%s), closing", p.handleId, event.Reason) + go p.Close(context.Background()) +} + +func (p *mcuJanusPublisher) handleDetached(event *janus.DetachedMsg) { + log.Printf("Publisher %d received detached, closing", p.handleId) + go p.Close(context.Background()) +} + +func (p *mcuJanusPublisher) handleConnected(event *janus.WebRTCUpMsg) { + log.Printf("Publisher %d received connected", p.handleId) + p.mcu.publisherConnected.Notify(getStreamId(p.id, p.streamType)) +} + +func (p *mcuJanusPublisher) handleSlowLink(event *janus.SlowLinkMsg) { + if event.Uplink { + log.Printf("Publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost) + } else { + log.Printf("Publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost) + } +} + +func (p *mcuJanusPublisher) handleMedia(event *janus.MediaMsg) { + mediaType := StreamType(event.Type) + if mediaType == StreamTypeVideo && p.streamType == StreamTypeScreen { + // We want to differentiate between audio, video and screensharing + mediaType = p.streamType + } + + p.stats.EnableStream(mediaType, event.Receiving) +} + +func (p *mcuJanusPublisher) HasMedia(mt MediaType) bool { + return (p.mediaTypes & mt) == mt +} + +func (p *mcuJanusPublisher) SetMedia(mt MediaType) { + p.mediaTypes = mt +} + +func (p *mcuJanusPublisher) NotifyReconnected() { + ctx := context.TODO() + handle, session, roomId, _, err := p.mcu.getOrCreatePublisherHandle(ctx, p.id, p.streamType, p.bitrate) + if err != nil { + log.Printf("Could not reconnect publisher %s: %s", p.id, err) + // TODO(jojo): Retry + return + } + + p.handle = handle + p.handleId = handle.Id + p.session = session + p.roomId = roomId + + log.Printf("Publisher %s reconnected on handle %d", p.id, p.handleId) +} + +func (p *mcuJanusPublisher) Close(ctx context.Context) { + notify := false + p.mu.Lock() + if handle := p.handle; handle != nil && p.roomId != 0 { + destroy_msg := map[string]interface{}{ + "request": "destroy", + "room": p.roomId, + } + if _, err := handle.Request(ctx, destroy_msg); err != nil { + log.Printf("Error destroying room %d: %s", p.roomId, err) + } else { + log.Printf("Room %d destroyed", p.roomId) + } + p.mcu.mu.Lock() + delete(p.mcu.publishers, getStreamId(p.id, p.streamType)) + p.mcu.mu.Unlock() + p.roomId = 0 + notify = true + } + p.closeClient(ctx) + p.mu.Unlock() + + p.stats.Reset() + + if notify { + statsPublishersCurrent.WithLabelValues(string(p.streamType)).Dec() + p.mcu.unregisterClient(p) + p.listener.PublisherClosed(p) + } + p.mcuJanusClient.Close(ctx) +} + +func (p *mcuJanusPublisher) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { + statsMcuMessagesTotal.WithLabelValues(data.Type).Inc() + jsep_msg := data.Payload + switch data.Type { + case "offer": + p.deferred <- func() { + if data.offerSdp == nil { + // Should have been checked before. + go callback(errors.New("No sdp found in offer"), nil) + return + } + + p.offerSdp.Store(data.offerSdp) + p.sdpFlags.Add(sdpHasOffer) + if p.sdpFlags.Get() == sdpHasAnswer|sdpHasOffer { + p.sdpReady.Close() + } + + // TODO Tear down previous publisher and get a new one if sid does + // not match? + msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) + defer cancel() + + p.sendOffer(msgctx, jsep_msg, func(err error, jsep map[string]interface{}) { + if err != nil { + callback(err, jsep) + return + } + + sdpData, found := jsep["sdp"] + if !found { + log.Printf("No sdp found in answer %+v", jsep) + } else { + sdpString, ok := sdpData.(string) + if !ok { + log.Printf("Invalid sdp found in answer %+v", jsep) + } else { + var answerSdp sdp.SessionDescription + if err := answerSdp.UnmarshalString(sdpString); err != nil { + log.Printf("Error parsing answer sdp %+v: %s", sdpString, err) + p.answerSdp.Store(nil) + p.sdpFlags.Remove(sdpHasAnswer) + } else { + p.answerSdp.Store(&answerSdp) + p.sdpFlags.Add(sdpHasAnswer) + if p.sdpFlags.Get() == sdpHasAnswer|sdpHasOffer { + p.sdpReady.Close() + } + } + } + } + + callback(nil, jsep) + }) + } + case "candidate": + p.deferred <- func() { + msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) + defer cancel() + + if data.Sid == "" || data.Sid == p.Sid() { + p.sendCandidate(msgctx, jsep_msg["candidate"], callback) + } else { + go callback(fmt.Errorf("Candidate message sid (%s) does not match publisher sid (%s)", data.Sid, p.Sid()), nil) + } + } + case "endOfCandidates": + // Ignore + default: + go callback(fmt.Errorf("Unsupported message type: %s", data.Type), nil) + } +} + +func getFmtpValue(fmtp string, key string) (string, bool) { + parts := strings.Split(fmtp, ";") + for _, part := range parts { + kv := strings.SplitN(part, "=", 2) + if len(kv) != 2 { + continue + } + + if strings.EqualFold(strings.TrimSpace(kv[0]), key) { + return strings.TrimSpace(kv[1]), true + } + + } + return "", false +} + +func (p *mcuJanusPublisher) GetStreams(ctx context.Context) ([]PublisherStream, error) { + offerSdp := p.offerSdp.Load() + answerSdp := p.answerSdp.Load() + if offerSdp == nil || answerSdp == nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-p.sdpReady.C: + offerSdp = p.offerSdp.Load() + answerSdp = p.answerSdp.Load() + if offerSdp == nil || answerSdp == nil { + // Only can happen on invalid SDPs. + return nil, errors.New("no offer and/or answer processed yet") + } + } + } + + var streams []PublisherStream + for idx, m := range answerSdp.MediaDescriptions { + mid, found := m.Attribute(sdp.AttrKeyMID) + if !found { + continue + } + + s := PublisherStream{ + Mid: mid, + Mindex: idx, + Type: m.MediaName.Media, + } + + if len(m.MediaName.Formats) == 0 { + continue + } + + if strings.EqualFold(s.Type, "application") && strings.EqualFold(m.MediaName.Formats[0], "webrtc-datachannel") { + s.Type = "data" + streams = append(streams, s) + continue + } + + pt, err := strconv.ParseInt(m.MediaName.Formats[0], 10, 8) + if err != nil { + continue + } + + answerCodec, err := answerSdp.GetCodecForPayloadType(uint8(pt)) + if err != nil { + continue + } + + if strings.EqualFold(s.Type, "audio") { + s.Codec = answerCodec.Name + if value, found := getFmtpValue(answerCodec.Fmtp, "useinbandfec"); found && value == "1" { + s.Fec = true + } + if value, found := getFmtpValue(answerCodec.Fmtp, "usedtx"); found && value == "1" { + s.Dtx = true + } + if value, found := getFmtpValue(answerCodec.Fmtp, "stereo"); found && value == "1" { + s.Stereo = true + } + } else if strings.EqualFold(s.Type, "video") { + s.Codec = answerCodec.Name + // TODO: Determine if SVC is used. + s.Svc = false + + if strings.EqualFold(answerCodec.Name, "vp9") { + // Parse VP9 profile from "profile-id=XXX" + // Exampe: "a=fmtp:98 profile-id=0" + if profile, found := getFmtpValue(answerCodec.Fmtp, "profile-id"); found { + s.ProfileVP9 = profile + } + } else if strings.EqualFold(answerCodec.Name, "h264") { + // Parse H.264 profile from "profile-level-id=XXX" + // Example: "a=fmtp:104 level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42001f" + if profile, found := getFmtpValue(answerCodec.Fmtp, "profile-level-id"); found { + s.ProfileH264 = profile + } + } + + var extmap sdp.ExtMap + for _, a := range m.Attributes { + switch a.Key { + case sdp.AttrKeyExtMap: + if err := extmap.Unmarshal(extmap.Name() + ":" + a.Value); err != nil { + log.Printf("Error parsing extmap %s: %s", a.Value, err) + continue + } + + switch extmap.URI.String() { + case ExtensionUrlPlayoutDelay: + s.ExtIdPlayoutDelay = extmap.Value + case ExtensionUrlVideoOrientation: + s.ExtIdVideoOrientation = extmap.Value + } + case "simulcast": + s.Simulcast = true + case sdp.AttrKeySSRCGroup: + if strings.HasPrefix(a.Value, "SIM ") { + s.Simulcast = true + } + } + } + + for _, a := range offerSdp.MediaDescriptions[idx].Attributes { + switch a.Key { + case "simulcast": + s.Simulcast = true + case sdp.AttrKeySSRCGroup: + if strings.HasPrefix(a.Value, "SIM ") { + s.Simulcast = true + } + } + } + + } else if strings.EqualFold(s.Type, "data") { // nolint + // Already handled above. + } else { + log.Printf("Skip type %s", s.Type) + continue + } + + streams = append(streams, s) + } + + return streams, nil +} + +func getPublisherRemoteId(id string, remoteId string) string { + return fmt.Sprintf("%s@%s", id, remoteId) +} + +func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { + msg := map[string]interface{}{ + "request": "publish_remotely", + "room": p.roomId, + "publisher_id": streamTypeUserIds[p.streamType], + "remote_id": getPublisherRemoteId(p.id, remoteId), + "host": hostname, + "port": port, + "rtcp_port": rtcpPort, + } + response, err := p.handle.Request(ctx, msg) + if err != nil { + return err + } + + errorMessage := getPluginStringValue(response.PluginData, pluginVideoRoom, "error") + errorCode := getPluginIntValue(response.PluginData, pluginVideoRoom, "error_code") + if errorMessage != "" || errorCode != 0 { + if errorCode == 0 { + errorCode = 500 + } + if errorMessage == "" { + errorMessage = "unknown error" + } + + return &janus.ErrorMsg{ + Err: janus.ErrorData{ + Code: int(errorCode), + Reason: errorMessage, + }, + } + } + + log.Printf("Publishing %s to %s (port=%d, rtcpPort=%d) for %s", p.id, hostname, port, rtcpPort, remoteId) + return nil +} + +func (p *mcuJanusPublisher) UnpublishRemote(ctx context.Context, remoteId string) error { + msg := map[string]interface{}{ + "request": "unpublish_remotely", + "room": p.roomId, + "publisher_id": streamTypeUserIds[p.streamType], + "remote_id": getPublisherRemoteId(p.id, remoteId), + } + response, err := p.handle.Request(ctx, msg) + if err != nil { + return err + } + + errorMessage := getPluginStringValue(response.PluginData, pluginVideoRoom, "error") + errorCode := getPluginIntValue(response.PluginData, pluginVideoRoom, "error_code") + if errorMessage != "" || errorCode != 0 { + if errorCode == 0 { + errorCode = 500 + } + if errorMessage == "" { + errorMessage = "unknown error" + } + + return &janus.ErrorMsg{ + Err: janus.ErrorData{ + Code: int(errorCode), + Reason: errorMessage, + }, + } + } + + log.Printf("Unpublished remote %s for %s", p.id, remoteId) + return nil +} diff --git a/mcu_janus_publisher_test.go b/mcu_janus_publisher_test.go new file mode 100644 index 0000000..dd81e79 --- /dev/null +++ b/mcu_janus_publisher_test.go @@ -0,0 +1,92 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "testing" +) + +func TestGetFmtpValueH264(t *testing.T) { + testcases := []struct { + fmtp string + profile string + }{ + { + "", + "", + }, + { + "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42001f", + "42001f", + }, + { + "level-asymmetry-allowed=1;packetization-mode=0", + "", + }, + { + "level-asymmetry-allowed=1; packetization-mode=0; profile-level-id = 42001f", + "42001f", + }, + } + + for _, tc := range testcases { + value, found := getFmtpValue(tc.fmtp, "profile-level-id") + if !found && tc.profile != "" { + t.Errorf("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp) + } else if found && tc.profile == "" { + t.Errorf("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value) + } else if found && tc.profile != value { + t.Errorf("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value) + } + } +} + +func TestGetFmtpValueVP9(t *testing.T) { + testcases := []struct { + fmtp string + profile string + }{ + { + "", + "", + }, + { + "profile-id=0", + "0", + }, + { + "profile-id = 0", + "0", + }, + } + + for _, tc := range testcases { + value, found := getFmtpValue(tc.fmtp, "profile-id") + if !found && tc.profile != "" { + t.Errorf("did not find profile \"%s\" in \"%s\"", tc.profile, tc.fmtp) + } else if found && tc.profile == "" { + t.Errorf("did not expect profile in \"%s\" but got \"%s\"", tc.fmtp, value) + } else if found && tc.profile != value { + t.Errorf("expected profile \"%s\" in \"%s\" but got \"%s\"", tc.profile, tc.fmtp, value) + } + } +} diff --git a/mcu_janus_remote_publisher.go b/mcu_janus_remote_publisher.go new file mode 100644 index 0000000..47593b0 --- /dev/null +++ b/mcu_janus_remote_publisher.go @@ -0,0 +1,150 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "log" + "sync/atomic" + + "github.com/notedit/janus-go" +) + +type mcuJanusRemotePublisher struct { + mcuJanusPublisher + + ref atomic.Int64 + + port int + rtcpPort int +} + +func (p *mcuJanusRemotePublisher) addRef() int64 { + return p.ref.Add(1) +} + +func (p *mcuJanusRemotePublisher) release() bool { + return p.ref.Add(-1) == 0 +} + +func (p *mcuJanusRemotePublisher) Port() int { + return p.port +} + +func (p *mcuJanusRemotePublisher) RtcpPort() int { + return p.rtcpPort +} + +func (p *mcuJanusRemotePublisher) handleEvent(event *janus.EventMsg) { + if videoroom := getPluginStringValue(event.Plugindata, pluginVideoRoom, "videoroom"); videoroom != "" { + ctx := context.TODO() + switch videoroom { + case "destroyed": + log.Printf("Remote publisher %d: associated room has been destroyed, closing", p.handleId) + go p.Close(ctx) + case "slow_link": + // Ignore, processed through "handleSlowLink" in the general events. + default: + log.Printf("Unsupported videoroom remote publisher event in %d: %+v", p.handleId, event) + } + } else { + log.Printf("Unsupported remote publisher event in %d: %+v", p.handleId, event) + } +} + +func (p *mcuJanusRemotePublisher) handleHangup(event *janus.HangupMsg) { + log.Printf("Remote publisher %d received hangup (%s), closing", p.handleId, event.Reason) + go p.Close(context.Background()) +} + +func (p *mcuJanusRemotePublisher) handleDetached(event *janus.DetachedMsg) { + log.Printf("Remote publisher %d received detached, closing", p.handleId) + go p.Close(context.Background()) +} + +func (p *mcuJanusRemotePublisher) handleConnected(event *janus.WebRTCUpMsg) { + log.Printf("Remote publisher %d received connected", p.handleId) + p.mcu.publisherConnected.Notify(getStreamId(p.id, p.streamType)) +} + +func (p *mcuJanusRemotePublisher) handleSlowLink(event *janus.SlowLinkMsg) { + if event.Uplink { + log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost) + } else { + log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost) + } +} + +func (p *mcuJanusRemotePublisher) NotifyReconnected() { + ctx := context.TODO() + handle, session, roomId, _, err := p.mcu.getOrCreatePublisherHandle(ctx, p.id, p.streamType, p.bitrate) + if err != nil { + log.Printf("Could not reconnect remote publisher %s: %s", p.id, err) + // TODO(jojo): Retry + return + } + + p.handle = handle + p.handleId = handle.Id + p.session = session + p.roomId = roomId + + log.Printf("Remote publisher %s reconnected on handle %d", p.id, p.handleId) +} + +func (p *mcuJanusRemotePublisher) Close(ctx context.Context) { + if !p.release() { + return + } + + p.mu.Lock() + if handle := p.handle; handle != nil { + response, err := p.handle.Request(ctx, map[string]interface{}{ + "request": "remove_remote_publisher", + "room": p.roomId, + "id": streamTypeUserIds[p.streamType], + }) + if err != nil { + log.Printf("Error removing remote publisher %s in room %d: %s", p.id, p.roomId, err) + } else { + log.Printf("Removed remote publisher: %+v", response) + } + if p.roomId != 0 { + destroy_msg := map[string]interface{}{ + "request": "destroy", + "room": p.roomId, + } + if _, err := handle.Request(ctx, destroy_msg); err != nil { + log.Printf("Error destroying room %d: %s", p.roomId, err) + } else { + log.Printf("Room %d destroyed", p.roomId) + } + p.mcu.mu.Lock() + delete(p.mcu.remotePublishers, getStreamId(p.id, p.streamType)) + p.mcu.mu.Unlock() + p.roomId = 0 + } + } + + p.closeClient(ctx) + p.mu.Unlock() +} diff --git a/mcu_janus_remote_subscriber.go b/mcu_janus_remote_subscriber.go new file mode 100644 index 0000000..0900416 --- /dev/null +++ b/mcu_janus_remote_subscriber.go @@ -0,0 +1,115 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "log" + "strconv" + "sync/atomic" + + "github.com/notedit/janus-go" +) + +type mcuJanusRemoteSubscriber struct { + mcuJanusSubscriber + + remote atomic.Pointer[mcuJanusRemotePublisher] +} + +func (p *mcuJanusRemoteSubscriber) handleEvent(event *janus.EventMsg) { + if videoroom := getPluginStringValue(event.Plugindata, pluginVideoRoom, "videoroom"); videoroom != "" { + ctx := context.TODO() + switch videoroom { + case "destroyed": + log.Printf("Remote subscriber %d: associated room has been destroyed, closing", p.handleId) + go p.Close(ctx) + case "event": + // Handle renegotiations, but ignore other events like selected + // substream / temporal layer. + if getPluginStringValue(event.Plugindata, pluginVideoRoom, "configured") == "ok" && + event.Jsep != nil && event.Jsep["type"] == "offer" && event.Jsep["sdp"] != nil { + p.listener.OnUpdateOffer(p, event.Jsep) + } + case "slow_link": + // Ignore, processed through "handleSlowLink" in the general events. + default: + log.Printf("Unsupported videoroom event %s for remote subscriber %d: %+v", videoroom, p.handleId, event) + } + } else { + log.Printf("Unsupported event for remote subscriber %d: %+v", p.handleId, event) + } +} + +func (p *mcuJanusRemoteSubscriber) handleHangup(event *janus.HangupMsg) { + log.Printf("Remote subscriber %d received hangup (%s), closing", p.handleId, event.Reason) + go p.Close(context.Background()) +} + +func (p *mcuJanusRemoteSubscriber) handleDetached(event *janus.DetachedMsg) { + log.Printf("Remote subscriber %d received detached, closing", p.handleId) + go p.Close(context.Background()) +} + +func (p *mcuJanusRemoteSubscriber) handleConnected(event *janus.WebRTCUpMsg) { + log.Printf("Remote subscriber %d received connected", p.handleId) + p.mcu.SubscriberConnected(p.Id(), p.publisher, p.streamType) +} + +func (p *mcuJanusRemoteSubscriber) handleSlowLink(event *janus.SlowLinkMsg) { + if event.Uplink { + log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost) + } else { + log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost) + } +} + +func (p *mcuJanusRemoteSubscriber) handleMedia(event *janus.MediaMsg) { + // Only triggered for publishers +} + +func (p *mcuJanusRemoteSubscriber) NotifyReconnected() { + ctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) + defer cancel() + handle, pub, err := p.mcu.getOrCreateSubscriberHandle(ctx, p.publisher, p.streamType) + if err != nil { + // TODO(jojo): Retry? + log.Printf("Could not reconnect remote subscriber for publisher %s: %s", p.publisher, err) + p.Close(context.Background()) + return + } + + p.handle = handle + p.handleId = handle.Id + p.roomId = pub.roomId + p.sid = strconv.FormatUint(handle.Id, 10) + p.listener.SubscriberSidUpdated(p) + log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId) +} + +func (p *mcuJanusRemoteSubscriber) Close(ctx context.Context) { + p.mcuJanusSubscriber.Close(ctx) + + if remote := p.remote.Swap(nil); remote != nil { + remote.Close(context.Background()) + } +} diff --git a/mcu_janus_stream_selection.go b/mcu_janus_stream_selection.go new file mode 100644 index 0000000..9381ef3 --- /dev/null +++ b/mcu_janus_stream_selection.go @@ -0,0 +1,110 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2017 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "database/sql" + "fmt" +) + +type streamSelection struct { + substream sql.NullInt16 + temporal sql.NullInt16 + audio sql.NullBool + video sql.NullBool +} + +func (s *streamSelection) HasValues() bool { + return s.substream.Valid || s.temporal.Valid || s.audio.Valid || s.video.Valid +} + +func (s *streamSelection) AddToMessage(message map[string]interface{}) { + if s.substream.Valid { + message["substream"] = s.substream.Int16 + } + if s.temporal.Valid { + message["temporal"] = s.temporal.Int16 + } + if s.audio.Valid { + message["audio"] = s.audio.Bool + } + if s.video.Valid { + message["video"] = s.video.Bool + } +} + +func parseStreamSelection(payload map[string]interface{}) (*streamSelection, error) { + var stream streamSelection + if value, found := payload["substream"]; found { + switch value := value.(type) { + case int: + stream.substream.Valid = true + stream.substream.Int16 = int16(value) + case float32: + stream.substream.Valid = true + stream.substream.Int16 = int16(value) + case float64: + stream.substream.Valid = true + stream.substream.Int16 = int16(value) + default: + return nil, fmt.Errorf("Unsupported substream value: %v", value) + } + } + + if value, found := payload["temporal"]; found { + switch value := value.(type) { + case int: + stream.temporal.Valid = true + stream.temporal.Int16 = int16(value) + case float32: + stream.temporal.Valid = true + stream.temporal.Int16 = int16(value) + case float64: + stream.temporal.Valid = true + stream.temporal.Int16 = int16(value) + default: + return nil, fmt.Errorf("Unsupported temporal value: %v", value) + } + } + + if value, found := payload["audio"]; found { + switch value := value.(type) { + case bool: + stream.audio.Valid = true + stream.audio.Bool = value + default: + return nil, fmt.Errorf("Unsupported audio value: %v", value) + } + } + + if value, found := payload["video"]; found { + switch value := value.(type) { + case bool: + stream.video.Valid = true + stream.video.Bool = value + default: + return nil, fmt.Errorf("Unsupported video value: %v", value) + } + } + + return &stream, nil +} diff --git a/mcu_janus_subscriber.go b/mcu_janus_subscriber.go new file mode 100644 index 0000000..b63f4e9 --- /dev/null +++ b/mcu_janus_subscriber.go @@ -0,0 +1,321 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2017 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "fmt" + "log" + "strconv" + + "github.com/notedit/janus-go" +) + +type mcuJanusSubscriber struct { + mcuJanusClient + + publisher string +} + +func (p *mcuJanusSubscriber) Publisher() string { + return p.publisher +} + +func (p *mcuJanusSubscriber) handleEvent(event *janus.EventMsg) { + if videoroom := getPluginStringValue(event.Plugindata, pluginVideoRoom, "videoroom"); videoroom != "" { + ctx := context.TODO() + switch videoroom { + case "destroyed": + log.Printf("Subscriber %d: associated room has been destroyed, closing", p.handleId) + go p.Close(ctx) + case "event": + // Handle renegotiations, but ignore other events like selected + // substream / temporal layer. + if getPluginStringValue(event.Plugindata, pluginVideoRoom, "configured") == "ok" && + event.Jsep != nil && event.Jsep["type"] == "offer" && event.Jsep["sdp"] != nil { + p.listener.OnUpdateOffer(p, event.Jsep) + } + case "slow_link": + // Ignore, processed through "handleSlowLink" in the general events. + default: + log.Printf("Unsupported videoroom event %s for subscriber %d: %+v", videoroom, p.handleId, event) + } + } else { + log.Printf("Unsupported event for subscriber %d: %+v", p.handleId, event) + } +} + +func (p *mcuJanusSubscriber) handleHangup(event *janus.HangupMsg) { + log.Printf("Subscriber %d received hangup (%s), closing", p.handleId, event.Reason) + go p.Close(context.Background()) +} + +func (p *mcuJanusSubscriber) handleDetached(event *janus.DetachedMsg) { + log.Printf("Subscriber %d received detached, closing", p.handleId) + go p.Close(context.Background()) +} + +func (p *mcuJanusSubscriber) handleConnected(event *janus.WebRTCUpMsg) { + log.Printf("Subscriber %d received connected", p.handleId) + p.mcu.SubscriberConnected(p.Id(), p.publisher, p.streamType) +} + +func (p *mcuJanusSubscriber) handleSlowLink(event *janus.SlowLinkMsg) { + if event.Uplink { + log.Printf("Subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost) + } else { + log.Printf("Subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost) + } +} + +func (p *mcuJanusSubscriber) handleMedia(event *janus.MediaMsg) { + // Only triggered for publishers +} + +func (p *mcuJanusSubscriber) NotifyReconnected() { + ctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) + defer cancel() + handle, pub, err := p.mcu.getOrCreateSubscriberHandle(ctx, p.publisher, p.streamType) + if err != nil { + // TODO(jojo): Retry? + log.Printf("Could not reconnect subscriber for publisher %s: %s", p.publisher, err) + p.Close(context.Background()) + return + } + + p.handle = handle + p.handleId = handle.Id + p.roomId = pub.roomId + p.sid = strconv.FormatUint(handle.Id, 10) + p.listener.SubscriberSidUpdated(p) + log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId) +} + +func (p *mcuJanusSubscriber) Close(ctx context.Context) { + p.mu.Lock() + closed := p.closeClient(ctx) + p.mu.Unlock() + + if closed { + p.mcu.SubscriberDisconnected(p.Id(), p.publisher, p.streamType) + statsSubscribersCurrent.WithLabelValues(string(p.streamType)).Dec() + } + p.mcu.unregisterClient(p) + p.listener.SubscriberClosed(p) + p.mcuJanusClient.Close(ctx) +} + +func (p *mcuJanusSubscriber) joinRoom(ctx context.Context, stream *streamSelection, callback func(error, map[string]interface{})) { + handle := p.handle + if handle == nil { + callback(ErrNotConnected, nil) + return + } + + waiter := p.mcu.publisherConnected.NewWaiter(getStreamId(p.publisher, p.streamType)) + defer p.mcu.publisherConnected.Release(waiter) + + loggedNotPublishingYet := false +retry: + join_msg := map[string]interface{}{ + "request": "join", + "ptype": "subscriber", + "room": p.roomId, + } + if p.mcu.isMultistream() { + join_msg["streams"] = []map[string]interface{}{ + { + "feed": streamTypeUserIds[p.streamType], + }, + } + } else { + join_msg["feed"] = streamTypeUserIds[p.streamType] + } + if stream != nil { + stream.AddToMessage(join_msg) + } + join_response, err := handle.Message(ctx, join_msg, nil) + if err != nil { + callback(err, nil) + return + } + + if error_code := getPluginIntValue(join_response.Plugindata, pluginVideoRoom, "error_code"); error_code > 0 { + switch error_code { + case JANUS_VIDEOROOM_ERROR_ALREADY_JOINED: + // The subscriber is already connected to the room. This can happen + // if a client leaves a call but keeps the subscriber objects active. + // On joining the call again, the subscriber tries to join on the + // MCU which will fail because he is still connected. + // To get a new Offer SDP, we have to tear down the session on the + // MCU and join again. + p.mu.Lock() + p.closeClient(ctx) + p.mu.Unlock() + + var pub *mcuJanusPublisher + handle, pub, err = p.mcu.getOrCreateSubscriberHandle(ctx, p.publisher, p.streamType) + if err != nil { + // Reconnection didn't work, need to unregister/remove subscriber + // so a new object will be created if the request is retried. + p.mcu.unregisterClient(p) + p.listener.SubscriberClosed(p) + callback(fmt.Errorf("Already connected as subscriber for %s, error during re-joining: %s", p.streamType, err), nil) + return + } + + p.handle = handle + p.handleId = handle.Id + p.roomId = pub.roomId + p.sid = strconv.FormatUint(handle.Id, 10) + p.listener.SubscriberSidUpdated(p) + p.closeChan = make(chan struct{}, 1) + go p.run(p.handle, p.closeChan) + log.Printf("Already connected subscriber %d for %s, leaving and re-joining on handle %d", p.id, p.streamType, p.handleId) + goto retry + case JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM: + fallthrough + case JANUS_VIDEOROOM_ERROR_NO_SUCH_FEED: + switch error_code { + case JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM: + log.Printf("Publisher %s not created yet for %s, wait and retry to join room %d as subscriber", p.publisher, p.streamType, p.roomId) + case JANUS_VIDEOROOM_ERROR_NO_SUCH_FEED: + log.Printf("Publisher %s not sending yet for %s, wait and retry to join room %d as subscriber", p.publisher, p.streamType, p.roomId) + } + + if !loggedNotPublishingYet { + loggedNotPublishingYet = true + statsWaitingForPublisherTotal.WithLabelValues(string(p.streamType)).Inc() + } + + if err := waiter.Wait(ctx); err != nil { + callback(err, nil) + return + } + log.Printf("Retry subscribing %s from %s", p.streamType, p.publisher) + goto retry + default: + // TODO(jojo): Should we handle other errors, too? + callback(fmt.Errorf("Error joining room as subscriber: %+v", join_response), nil) + return + } + } + //log.Println("Joined as listener", join_response) + + p.session = join_response.Session + callback(nil, join_response.Jsep) +} + +func (p *mcuJanusSubscriber) update(ctx context.Context, stream *streamSelection, callback func(error, map[string]interface{})) { + handle := p.handle + if handle == nil { + callback(ErrNotConnected, nil) + return + } + + configure_msg := map[string]interface{}{ + "request": "configure", + "update": true, + } + if stream != nil { + stream.AddToMessage(configure_msg) + } + configure_response, err := handle.Message(ctx, configure_msg, nil) + if err != nil { + callback(err, nil) + return + } + + callback(nil, configure_response.Jsep) +} + +func (p *mcuJanusSubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { + statsMcuMessagesTotal.WithLabelValues(data.Type).Inc() + jsep_msg := data.Payload + switch data.Type { + case "requestoffer": + fallthrough + case "sendoffer": + p.deferred <- func() { + msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) + defer cancel() + + stream, err := parseStreamSelection(jsep_msg) + if err != nil { + go callback(err, nil) + return + } + + if data.Sid == "" || data.Sid != p.Sid() { + p.joinRoom(msgctx, stream, callback) + } else { + p.update(msgctx, stream, callback) + } + } + case "answer": + p.deferred <- func() { + msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) + defer cancel() + + if data.Sid == "" || data.Sid == p.Sid() { + p.sendAnswer(msgctx, jsep_msg, callback) + } else { + go callback(fmt.Errorf("Answer message sid (%s) does not match subscriber sid (%s)", data.Sid, p.Sid()), nil) + } + } + case "candidate": + p.deferred <- func() { + msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) + defer cancel() + + if data.Sid == "" || data.Sid == p.Sid() { + p.sendCandidate(msgctx, jsep_msg["candidate"], callback) + } else { + go callback(fmt.Errorf("Candidate message sid (%s) does not match subscriber sid (%s)", data.Sid, p.Sid()), nil) + } + } + case "endOfCandidates": + // Ignore + case "selectStream": + stream, err := parseStreamSelection(jsep_msg) + if err != nil { + go callback(err, nil) + return + } + + if stream == nil || !stream.HasValues() { + // Nothing to do + go callback(nil, nil) + return + } + + p.deferred <- func() { + msgctx, cancel := context.WithTimeout(context.Background(), p.mcu.mcuTimeout) + defer cancel() + + p.selectStream(msgctx, stream, callback) + } + default: + // Return error asynchronously + go callback(fmt.Errorf("Unsupported message type: %s", data.Type), nil) + } +} diff --git a/mcu_proxy.go b/mcu_proxy.go index 31a3191..5b34426 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -218,13 +218,26 @@ func (p *mcuProxyPublisher) ProcessEvent(msg *EventProxyServerMessage) { } } +func (p *mcuProxyPublisher) GetStreams(ctx context.Context) ([]PublisherStream, error) { + return nil, errors.New("not implemented") +} + +func (p *mcuProxyPublisher) PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { + return errors.New("remote publishing not supported for proxy publishers") +} + +func (p *mcuProxyPublisher) UnpublishRemote(ctx context.Context, remoteId string) error { + return errors.New("remote publishing not supported for proxy publishers") +} + type mcuProxySubscriber struct { mcuProxyPubSubCommon - publisherId string + publisherId string + publisherConn *mcuProxyConnection } -func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType, maxBitrate int, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxySubscriber { +func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType, maxBitrate int, proxyId string, conn *mcuProxyConnection, listener McuListener, publisherConn *mcuProxyConnection) *mcuProxySubscriber { return &mcuProxySubscriber{ mcuProxyPubSubCommon: mcuProxyPubSubCommon{ sid: sid, @@ -235,7 +248,8 @@ func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType listener: listener, }, - publisherId: publisherId, + publisherId: publisherId, + publisherConn: publisherConn, } } @@ -244,7 +258,11 @@ func (s *mcuProxySubscriber) Publisher() string { } func (s *mcuProxySubscriber) NotifyClosed() { - log.Printf("Subscriber %s at %s was closed", s.proxyId, s.conn) + if s.publisherConn != nil { + log.Printf("Remote subscriber %s at %s (forwarded to %s) was closed", s.proxyId, s.conn, s.publisherConn) + } else { + log.Printf("Subscriber %s at %s was closed", s.proxyId, s.conn) + } s.listener.SubscriberClosed(s) s.conn.removeSubscriber(s) } @@ -261,14 +279,26 @@ func (s *mcuProxySubscriber) Close(ctx context.Context) { } if response, err := s.conn.performSyncRequest(ctx, msg); err != nil { - log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, err) + if s.publisherConn != nil { + log.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", s.proxyId, s.conn, s.publisherConn, err) + } else { + log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, err) + } return } else if response.Type == "error" { - log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, response.Error) + if s.publisherConn != nil { + log.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", s.proxyId, s.conn, s.publisherConn, response.Error) + } else { + log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, response.Error) + } return } - log.Printf("Deleted subscriber %s at %s", s.proxyId, s.conn) + if s.publisherConn != nil { + log.Printf("Deleted remote subscriber %s at %s (forwarded to %s)", s.proxyId, s.conn, s.publisherConn) + } else { + log.Printf("Deleted subscriber %s at %s", s.proxyId, s.conn) + } } func (s *mcuProxySubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { @@ -310,6 +340,7 @@ type mcuProxyConnection struct { ip net.IP load atomic.Int64 + bandwidth atomic.Pointer[EventProxyServerBandwidth] mu sync.Mutex closer *Closer closedDone *Closer @@ -328,7 +359,7 @@ type mcuProxyConnection struct { msgId atomic.Int64 helloMsgId string - sessionId string + sessionId atomic.Value country atomic.Value callbacks map[string]func(*ProxyServerMessage) @@ -361,6 +392,7 @@ func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP) (*mcuProx } conn.reconnectInterval.Store(int64(initialReconnectInterval)) conn.load.Store(loadNotConnected) + conn.bandwidth.Store(nil) conn.country.Store("") return conn, nil } @@ -373,6 +405,54 @@ func (c *mcuProxyConnection) String() string { return c.rawUrl } +func (c *mcuProxyConnection) IsSameCountry(initiator McuInitiator) bool { + if initiator == nil { + return true + } + + initiatorCountry := initiator.Country() + if initiatorCountry == "" { + return true + } + + connCountry := c.Country() + if connCountry == "" { + return true + } + + return initiatorCountry == connCountry +} + +func (c *mcuProxyConnection) IsSameContinent(initiator McuInitiator) bool { + if initiator == nil { + return true + } + + initiatorCountry := initiator.Country() + if initiatorCountry == "" { + return true + } + + connCountry := c.Country() + if connCountry == "" { + return true + } + + initiatorContinents, found := ContinentMap[initiatorCountry] + if found { + m := c.proxy.getContinentsMap() + // Map continents to other continents (e.g. use Europe for Africa). + for _, continent := range initiatorContinents { + if toAdd, found := m[continent]; found { + initiatorContinents = append(initiatorContinents, toAdd...) + } + } + + } + connContinents := ContinentMap[connCountry] + return ContinentsOverlap(initiatorContinents, connContinents) +} + type mcuProxyConnectionStats struct { Url string `json:"url"` IP net.IP `json:"ip,omitempty"` @@ -416,10 +496,29 @@ func (c *mcuProxyConnection) Load() int64 { return c.load.Load() } +func (c *mcuProxyConnection) Bandwidth() *EventProxyServerBandwidth { + return c.bandwidth.Load() +} + func (c *mcuProxyConnection) Country() string { return c.country.Load().(string) } +func (c *mcuProxyConnection) SessionId() string { + sid := c.sessionId.Load() + if sid == nil { + return "" + } + + return sid.(string) +} + +func (c *mcuProxyConnection) IsConnected() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.conn != nil && c.SessionId() != "" +} + func (c *mcuProxyConnection) IsTemporary() bool { return c.temporary.Load() } @@ -445,7 +544,10 @@ func (c *mcuProxyConnection) readPump() { } }() defer c.close() - defer c.load.Store(loadNotConnected) + defer func() { + c.load.Store(loadNotConnected) + c.bandwidth.Store(nil) + }() c.mu.Lock() conn := c.conn @@ -810,11 +912,11 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { switch msg.Type { case "error": if msg.Error.Code == "no_such_session" { - log.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c) + log.Printf("Session %s could not be resumed on %s, registering new", c.SessionId(), c) c.clearPublishers() c.clearSubscribers() c.clearCallbacks() - c.sessionId = "" + c.sessionId.Store("") if err := c.sendHello(); err != nil { log.Printf("Could not send hello request to %s: %s", c, err) c.scheduleReconnect() @@ -825,8 +927,8 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { log.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error) c.scheduleReconnect() case "hello": - resumed := c.sessionId == msg.Hello.SessionId - c.sessionId = msg.Hello.SessionId + resumed := c.SessionId() == msg.Hello.SessionId + c.sessionId.Store(msg.Hello.SessionId) country := "" if msg.Hello.Server != nil { if country = msg.Hello.Server.Country; country != "" && !IsValidCountry(country) { @@ -836,11 +938,11 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { } c.country.Store(country) if resumed { - log.Printf("Resumed session %s on %s", c.sessionId, c) + log.Printf("Resumed session %s on %s", c.SessionId(), c) } else if country != "" { - log.Printf("Received session %s from %s (in %s)", c.sessionId, c, country) + log.Printf("Received session %s from %s (in %s)", c.SessionId(), c, country) } else { - log.Printf("Received session %s from %s", c.sessionId, c) + log.Printf("Received session %s from %s", c.SessionId(), c) } if c.trackClose.CompareAndSwap(false, true) { statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Inc() @@ -911,9 +1013,10 @@ func (c *mcuProxyConnection) processEvent(msg *ProxyServerMessage) { return case "update-load": if proxyDebugMessages { - log.Printf("Load of %s now at %d", c, event.Load) + log.Printf("Load of %s now at %d (%s)", c, event.Load, event.Bandwidth) } c.load.Store(event.Load) + c.bandwidth.Store(event.Bandwidth) statsProxyBackendLoadCurrent.WithLabelValues(c.url.String()).Set(float64(event.Load)) return case "shutdown-scheduled": @@ -948,8 +1051,8 @@ func (c *mcuProxyConnection) processBye(msg *ProxyServerMessage) { bye := msg.Bye switch bye.Reason { case "session_resumed": - log.Printf("Session %s on %s was resumed by other client, resetting", c.sessionId, c) - c.sessionId = "" + log.Printf("Session %s on %s was resumed by other client, resetting", c.SessionId(), c) + c.sessionId.Store("") default: log.Printf("Received bye with unsupported reason from %s %+v", c, bye) } @@ -964,17 +1067,10 @@ func (c *mcuProxyConnection) sendHello() error { Version: "1.0", }, } - if c.sessionId != "" { - msg.Hello.ResumeId = c.sessionId + if sessionId := c.SessionId(); sessionId != "" { + msg.Hello.ResumeId = sessionId } else { - claims := &TokenClaims{ - jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(time.Now()), - Issuer: c.proxy.tokenId, - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - tokenString, err := token.SignedString(c.proxy.tokenKey) + tokenString, err := c.proxy.createToken("") if err != nil { return err } @@ -1095,7 +1191,48 @@ func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuList proxyId := response.Command.Id log.Printf("Created %s subscriber %s on %s for %s", streamType, proxyId, c, publisherSessionId) - subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener) + subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener, nil) + c.subscribersLock.Lock() + c.subscribers[proxyId] = subscriber + c.subscribersLock.Unlock() + statsSubscribersCurrent.WithLabelValues(string(streamType)).Inc() + statsSubscribersTotal.WithLabelValues(string(streamType)).Inc() + return subscriber, nil +} + +func (c *mcuProxyConnection) newRemoteSubscriber(ctx context.Context, listener McuListener, publisherId string, publisherSessionId string, streamType StreamType, publisherConn *mcuProxyConnection) (McuSubscriber, error) { + if c == publisherConn { + return c.newSubscriber(ctx, listener, publisherId, publisherSessionId, streamType) + } + + remoteToken, err := c.proxy.createToken(publisherId) + if err != nil { + return nil, err + } + + msg := &ProxyClientMessage{ + Type: "command", + Command: &CommandProxyClientMessage{ + Type: "create-subscriber", + StreamType: streamType, + PublisherId: publisherId, + + RemoteUrl: publisherConn.rawUrl, + RemoteToken: remoteToken, + }, + } + + response, err := c.performSyncRequest(ctx, msg) + if err != nil { + // TODO: Cancel request + return nil, err + } else if response.Type == "error" { + return nil, fmt.Errorf("Error creating remote %s subscriber for %s on %s (forwarded to %s): %+v", streamType, publisherSessionId, c, publisherConn, response.Error) + } + + proxyId := response.Command.Id + log.Printf("Created remote %s subscriber %s on %s for %s (forwarded to %s)", streamType, proxyId, c, publisherSessionId, publisherConn) + subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener, publisherConn) c.subscribersLock.Lock() c.subscribers[proxyId] = subscriber c.subscribersLock.Unlock() @@ -1258,7 +1395,7 @@ func (m *mcuProxy) loadContinentsMap(config *goconf.ConfigFile) error { return nil } -func (m *mcuProxy) Start() error { +func (m *mcuProxy) Start(ctx context.Context) error { log.Printf("Maximum bandwidth %d bits/sec per publishing stream", m.maxStreamBitrate) log.Printf("Maximum bandwidth %d bits/sec per screensharing stream", m.maxScreenBitrate) @@ -1278,6 +1415,48 @@ func (m *mcuProxy) Stop() { m.config.Stop() } +func (m *mcuProxy) createToken(subject string) (string, error) { + claims := &TokenClaims{ + jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: m.tokenId, + Subject: subject, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(m.tokenKey) + if err != nil { + return "", err + } + + return tokenString, nil +} + +func (m *mcuProxy) hasConnections() bool { + m.connectionsMu.RLock() + defer m.connectionsMu.RUnlock() + for _, conn := range m.connections { + if conn.IsConnected() { + return true + } + } + return false +} + +func (m *mcuProxy) WaitForConnections(ctx context.Context) error { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for !m.hasConnections() { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } + return nil +} + func (m *mcuProxy) AddConnection(ignoreErrors bool, url string, ips ...net.IP) error { m.connectionsMu.Lock() defer m.connectionsMu.Unlock() @@ -1569,27 +1748,27 @@ func (m *mcuProxy) removePublisher(publisher *mcuProxyPublisher) { delete(m.publishers, getStreamId(publisher.id, publisher.StreamType())) } -func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { - connections := m.getSortedConnections(initiator) +func (m *mcuProxy) createPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator, connections []*mcuProxyConnection, isAllowed func(c *mcuProxyConnection) bool) McuPublisher { + var maxBitrate int + if streamType == StreamTypeScreen { + maxBitrate = m.maxScreenBitrate + } else { + maxBitrate = m.maxStreamBitrate + } + if bitrate <= 0 { + bitrate = maxBitrate + } else { + bitrate = min(bitrate, maxBitrate) + } + for _, conn := range connections { - if conn.IsShutdownScheduled() || conn.IsTemporary() { + if !isAllowed(conn) || conn.IsShutdownScheduled() || conn.IsTemporary() { continue } subctx, cancel := context.WithTimeout(ctx, m.proxyTimeout) defer cancel() - var maxBitrate int - if streamType == StreamTypeScreen { - maxBitrate = m.maxScreenBitrate - } else { - maxBitrate = m.maxStreamBitrate - } - if bitrate <= 0 { - bitrate = maxBitrate - } else { - bitrate = min(bitrate, maxBitrate) - } publisher, err := conn.newPublisher(subctx, listener, id, sid, streamType, bitrate, mediaTypes) if err != nil { log.Printf("Could not create %s publisher for %s on %s: %s", streamType, id, conn, err) @@ -1600,11 +1779,61 @@ func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id st m.publishers[getStreamId(id, streamType)] = conn m.mu.Unlock() m.publisherWaiters.Wakeup() - return publisher, nil + return publisher } - statsProxyNobackendAvailableTotal.WithLabelValues(string(streamType)).Inc() - return nil, fmt.Errorf("No MCU connection available") + return nil +} + +func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { + connections := m.getSortedConnections(initiator) + publisher := m.createPublisher(ctx, listener, id, sid, streamType, bitrate, mediaTypes, initiator, connections, func(c *mcuProxyConnection) bool { + bw := c.Bandwidth() + return bw == nil || bw.AllowIncoming() + }) + if publisher == nil { + // No proxy has available bandwidth, select one with the lowest currently used bandwidth. + connections2 := make([]*mcuProxyConnection, 0, len(connections)) + for _, c := range connections { + if c.Bandwidth() != nil { + connections2 = append(connections2, c) + } + } + SlicesSortFunc(connections2, func(a *mcuProxyConnection, b *mcuProxyConnection) int { + var incoming_a *float64 + if bw := a.Bandwidth(); bw != nil { + incoming_a = bw.Incoming + } + + var incoming_b *float64 + if bw := b.Bandwidth(); bw != nil { + incoming_b = bw.Incoming + } + + if incoming_a == nil && incoming_b == nil { + return 0 + } else if incoming_a == nil && incoming_b != nil { + return -1 + } else if incoming_a != nil && incoming_b == nil { + return -1 + } else if *incoming_a < *incoming_b { + return -1 + } else if *incoming_a > *incoming_b { + return 1 + } + return 0 + }) + publisher = m.createPublisher(ctx, listener, id, sid, streamType, bitrate, mediaTypes, initiator, connections2, func(c *mcuProxyConnection) bool { + return true + }) + } + + if publisher == nil { + statsProxyNobackendAvailableTotal.WithLabelValues(string(streamType)).Inc() + return nil, fmt.Errorf("No MCU connection available") + } + + return publisher, nil } func (m *mcuProxy) getPublisherConnection(publisher string, streamType StreamType) *mcuProxyConnection { @@ -1645,7 +1874,38 @@ func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher str } } -func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) { +type proxyPublisherInfo struct { + id string + conn *mcuProxyConnection + err error +} + +func (m *mcuProxy) createSubscriber(ctx context.Context, listener McuListener, id string, publisher string, streamType StreamType, publisherConn *mcuProxyConnection, connections []*mcuProxyConnection, isAllowed func(c *mcuProxyConnection) bool) McuSubscriber { + for _, conn := range connections { + if !isAllowed(conn) || conn.IsShutdownScheduled() || conn.IsTemporary() { + continue + } + + var subscriber McuSubscriber + var err error + if conn == publisherConn { + subscriber, err = conn.newSubscriber(ctx, listener, id, publisher, streamType) + } else { + subscriber, err = conn.newRemoteSubscriber(ctx, listener, id, publisher, streamType, publisherConn) + } + if err != nil { + log.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, conn, err) + continue + } + + return subscriber + } + + return nil +} + +func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) { + var publisherInfo *proxyPublisherInfo if conn := m.getPublisherConnection(publisher, streamType); conn != nil { // Fast common path: publisher is available locally. conn.publishersLock.Lock() @@ -1655,113 +1915,190 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ return nil, fmt.Errorf("Unknown publisher %s", publisher) } - return conn.newSubscriber(ctx, listener, id, publisher, streamType) - } - - log.Printf("No %s publisher %s found yet, deferring", streamType, publisher) - ch := make(chan McuSubscriber) - getctx, cancel := context.WithCancel(ctx) - defer cancel() - - // Wait for publisher to be created locally. - go func() { - if conn := m.waitForPublisherConnection(getctx, publisher, streamType); conn != nil { - cancel() // Cancel pending RPC calls. - - conn.publishersLock.Lock() - id, found := conn.publisherIds[getStreamId(publisher, streamType)] - conn.publishersLock.Unlock() - if !found { - log.Printf("Unknown id for local %s publisher %s", streamType, publisher) - return - } - - subscriber, err := conn.newSubscriber(ctx, listener, id, publisher, streamType) - if subscriber != nil { - ch <- subscriber - } else if err != nil { - log.Printf("Error creating local subscriber for %s publisher %s: %s", streamType, publisher, err) - } + publisherInfo = &proxyPublisherInfo{ + id: id, + conn: conn, } - }() + } else { + log.Printf("No %s publisher %s found yet, deferring", streamType, publisher) + ch := make(chan *proxyPublisherInfo, 1) + getctx, cancel := context.WithCancel(ctx) + defer cancel() - // Wait for publisher to be created on one of the other servers in the cluster. - if clients := m.rpcClients.GetClients(); len(clients) > 0 { - for _, client := range clients { - go func(client *GrpcClient) { - id, url, ip, err := client.GetPublisherId(getctx, publisher, streamType) - if errors.Is(err, context.Canceled) { - return - } else if err != nil { - log.Printf("Error getting %s publisher id %s from %s: %s", streamType, publisher, client.Target(), err) - return - } else if id == "" { - // Publisher not found on other server - return - } + var wg sync.WaitGroup + // Wait for publisher to be created locally. + wg.Add(1) + go func() { + defer wg.Done() + if conn := m.waitForPublisherConnection(getctx, publisher, streamType); conn != nil { cancel() // Cancel pending RPC calls. - log.Printf("Found publisher id %s through %s on proxy %s", id, client.Target(), url) - m.connectionsMu.RLock() - connections := m.connections - m.connectionsMu.RUnlock() - var publisherConn *mcuProxyConnection - for _, conn := range connections { - if conn.rawUrl != url || !ip.Equal(conn.ip) { - continue + conn.publishersLock.Lock() + id, found := conn.publisherIds[getStreamId(publisher, streamType)] + conn.publishersLock.Unlock() + if !found { + ch <- &proxyPublisherInfo{ + err: fmt.Errorf("Unknown id for local %s publisher %s", streamType, publisher), } - - // Simple case, signaling server has a connection to the same endpoint - publisherConn = conn - break - } - - if publisherConn == nil { - publisherConn, err = newMcuProxyConnection(m, url, ip) - if err != nil { - log.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err) - return - } - - publisherConn.setTemporary() - publisherConn.start() - if err := publisherConn.waitUntilConnected(ctx); err != nil { - log.Printf("Could not establish new connection to %s: %s", publisherConn, err) - publisherConn.closeIfEmpty() - return - } - - m.connectionsMu.Lock() - m.connections = append(m.connections, publisherConn) - conns, found := m.connectionsMap[url] - if found { - conns = append(conns, publisherConn) - } else { - conns = []*mcuProxyConnection{publisherConn} - } - m.connectionsMap[url] = conns - m.connectionsMu.Unlock() - } - - subscriber, err := publisherConn.newSubscriber(ctx, listener, id, publisher, streamType) - if err != nil { - if publisherConn.IsTemporary() { - publisherConn.closeIfEmpty() - } - log.Printf("Could not create subscriber for %s publisher %s: %s", streamType, publisher, err) return } - ch <- subscriber - }(client) + ch <- &proxyPublisherInfo{ + id: id, + conn: conn, + } + } + }() + + // Wait for publisher to be created on one of the other servers in the cluster. + if clients := m.rpcClients.GetClients(); len(clients) > 0 { + for _, client := range clients { + wg.Add(1) + go func(client *GrpcClient) { + defer wg.Done() + id, url, ip, err := client.GetPublisherId(getctx, publisher, streamType) + if errors.Is(err, context.Canceled) { + return + } else if err != nil { + log.Printf("Error getting %s publisher id %s from %s: %s", streamType, publisher, client.Target(), err) + return + } else if id == "" { + // Publisher not found on other server + return + } + + cancel() // Cancel pending RPC calls. + log.Printf("Found publisher id %s through %s on proxy %s", id, client.Target(), url) + + m.connectionsMu.RLock() + connections := m.connections + m.connectionsMu.RUnlock() + var publisherConn *mcuProxyConnection + for _, conn := range connections { + if conn.rawUrl != url || !ip.Equal(conn.ip) { + continue + } + + // Simple case, signaling server has a connection to the same endpoint + publisherConn = conn + break + } + + if publisherConn == nil { + publisherConn, err = newMcuProxyConnection(m, url, ip) + if err != nil { + log.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err) + return + } + + publisherConn.setTemporary() + publisherConn.start() + if err := publisherConn.waitUntilConnected(ctx); err != nil { + log.Printf("Could not establish new connection to %s: %s", publisherConn, err) + publisherConn.closeIfEmpty() + return + } + + m.connectionsMu.Lock() + m.connections = append(m.connections, publisherConn) + conns, found := m.connectionsMap[url] + if found { + conns = append(conns, publisherConn) + } else { + conns = []*mcuProxyConnection{publisherConn} + } + m.connectionsMap[url] = conns + m.connectionsMu.Unlock() + } + + ch <- &proxyPublisherInfo{ + id: id, + conn: publisherConn, + } + }(client) + } + } + + wg.Wait() + select { + case ch <- &proxyPublisherInfo{ + err: fmt.Errorf("No %s publisher %s found", streamType, publisher), + }: + default: + } + + select { + case info := <-ch: + publisherInfo = info + case <-ctx.Done(): + return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher) } } - select { - case subscriber := <-ch: - return subscriber, nil - case <-ctx.Done(): - return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher) + if publisherInfo.err != nil { + return nil, publisherInfo.err } + + bw := publisherInfo.conn.Bandwidth() + allowOutgoing := bw == nil || bw.AllowOutgoing() + if !allowOutgoing || !publisherInfo.conn.IsSameCountry(initiator) { + connections := m.getSortedConnections(initiator) + if !allowOutgoing || len(connections) > 0 && !connections[0].IsSameCountry(publisherInfo.conn) { + // Connect to remote publisher through "closer" gateway. + subscriber := m.createSubscriber(ctx, listener, publisherInfo.id, publisher, streamType, publisherInfo.conn, connections, func(c *mcuProxyConnection) bool { + bw := c.Bandwidth() + return bw == nil || bw.AllowOutgoing() + }) + if subscriber == nil { + connections2 := make([]*mcuProxyConnection, 0, len(connections)) + for _, c := range connections { + if c.Bandwidth() != nil { + connections2 = append(connections2, c) + } + } + SlicesSortFunc(connections2, func(a *mcuProxyConnection, b *mcuProxyConnection) int { + var outgoing_a *float64 + if bw := a.Bandwidth(); bw != nil { + outgoing_a = bw.Outgoing + } + + var outgoing_b *float64 + if bw := b.Bandwidth(); bw != nil { + outgoing_b = bw.Outgoing + } + + if outgoing_a == nil && outgoing_b == nil { + return 0 + } else if outgoing_a == nil && outgoing_b != nil { + return -1 + } else if outgoing_a != nil && outgoing_b == nil { + return -1 + } else if *outgoing_a < *outgoing_b { + return -1 + } else if *outgoing_a > *outgoing_b { + return 1 + } + return 0 + }) + subscriber = m.createSubscriber(ctx, listener, publisherInfo.id, publisher, streamType, publisherInfo.conn, connections2, func(c *mcuProxyConnection) bool { + return true + }) + } + if subscriber != nil { + return subscriber, nil + } + } + } + + subscriber, err := publisherInfo.conn.newSubscriber(ctx, listener, publisherInfo.id, publisher, streamType) + if err != nil { + if publisherInfo.conn.IsTemporary() { + publisherInfo.conn.closeIfEmpty() + } + log.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, publisherInfo.conn, err) + return nil, err + } + + return subscriber, nil } diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index e518e6d..39f12a9 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go @@ -22,7 +22,26 @@ package signaling import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "path" + "strings" + "sync" + "sync/atomic" "testing" + "time" + + "github.com/dlintw/goconf" + "github.com/gorilla/websocket" + "go.etcd.io/etcd/server/v3/embed" ) func TestMcuProxyStats(t *testing.T) { @@ -166,3 +185,1542 @@ func Test_sortConnectionsForCountryWithOverride(t *testing.T) { }) } } + +type proxyServerClientHandler func(msg *ProxyClientMessage) (*ProxyServerMessage, error) + +type testProxyServerPublisher struct { + id string +} + +type testProxyServerSubscriber struct { + id string + sid string + pub *testProxyServerPublisher + + remoteUrl string +} + +type testProxyServerClient struct { + t *testing.T + + server *TestProxyServerHandler + ws *websocket.Conn + processMessage proxyServerClientHandler + + mu sync.Mutex + sessionId string +} + +func (c *testProxyServerClient) processHello(msg *ProxyClientMessage) (*ProxyServerMessage, error) { + if msg.Type != "hello" { + return nil, fmt.Errorf("expected hello, got %+v", msg) + } + + response := &ProxyServerMessage{ + Id: msg.Id, + Type: "hello", + Hello: &HelloProxyServerMessage{ + Version: "1.0", + SessionId: c.sessionId, + Server: &WelcomeServerMessage{ + Version: "1.0", + Country: c.server.country, + }, + }, + } + c.processMessage = c.processRegularMessage + return response, nil +} + +func (c *testProxyServerClient) processRegularMessage(msg *ProxyClientMessage) (*ProxyServerMessage, error) { + var handler proxyServerClientHandler + switch msg.Type { + case "command": + handler = c.processCommandMessage + } + + if handler == nil { + response := msg.NewWrappedErrorServerMessage(fmt.Errorf("type \"%s\" is not implemented", msg.Type)) + return response, nil + } + + return handler(msg) +} + +func (c *testProxyServerClient) processCommandMessage(msg *ProxyClientMessage) (*ProxyServerMessage, error) { + var response *ProxyServerMessage + switch msg.Command.Type { + case "create-publisher": + pub := c.server.createPublisher() + + response = &ProxyServerMessage{ + Id: msg.Id, + Type: "command", + Command: &CommandProxyServerMessage{ + Id: pub.id, + Bitrate: msg.Command.Bitrate, + }, + } + c.server.updateLoad(1) + case "delete-publisher": + if pub, found := c.server.deletePublisher(msg.Command.ClientId); !found { + response = msg.NewWrappedErrorServerMessage(fmt.Errorf("publisher %s not found", msg.Command.ClientId)) + } else { + response = &ProxyServerMessage{ + Id: msg.Id, + Type: "command", + Command: &CommandProxyServerMessage{ + Id: pub.id, + }, + } + c.server.updateLoad(-1) + } + case "create-subscriber": + var pub *testProxyServerPublisher + if msg.Command.RemoteUrl != "" { + for _, server := range c.server.servers { + if server.URL != msg.Command.RemoteUrl { + continue + } + + pub = server.getPublisher(msg.Command.PublisherId) + break + } + } else { + pub = c.server.getPublisher(msg.Command.PublisherId) + } + + if pub == nil { + response = msg.NewWrappedErrorServerMessage(fmt.Errorf("publisher %s not found", msg.Command.PublisherId)) + } else { + sub := c.server.createSubscriber(pub) + response = &ProxyServerMessage{ + Id: msg.Id, + Type: "command", + Command: &CommandProxyServerMessage{ + Id: sub.id, + Sid: sub.sid, + }, + } + c.server.updateLoad(1) + } + case "delete-subscriber": + if sub, found := c.server.deleteSubscriber(msg.Command.ClientId); !found { + response = msg.NewWrappedErrorServerMessage(fmt.Errorf("subscriber %s not found", msg.Command.ClientId)) + } else { + if msg.Command.RemoteUrl != sub.remoteUrl { + response = msg.NewWrappedErrorServerMessage(fmt.Errorf("remote subscriber %s not found", msg.Command.ClientId)) + return response, nil + } + + response = &ProxyServerMessage{ + Id: msg.Id, + Type: "command", + Command: &CommandProxyServerMessage{ + Id: sub.id, + }, + } + c.server.updateLoad(-1) + } + } + if response == nil { + response = msg.NewWrappedErrorServerMessage(fmt.Errorf("command \"%s\" is not implemented", msg.Command.Type)) + } + + return response, nil +} + +func (c *testProxyServerClient) close() { + c.mu.Lock() + defer c.mu.Unlock() + + c.ws.Close() + c.ws = nil +} + +func (c *testProxyServerClient) handleSendMessageError(fmt string, msg *ProxyServerMessage, err error) { + c.t.Helper() + + if !errors.Is(err, websocket.ErrCloseSent) || msg.Type != "event" || msg.Event.Type != "update-load" { + c.t.Errorf(fmt, msg, err) + } +} + +func (c *testProxyServerClient) sendMessage(msg *ProxyServerMessage) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.ws == nil { + return + } + + data, err := json.Marshal(msg) + if err != nil { + c.handleSendMessageError("error marshalling %+v: %s", msg, err) + return + } + + w, err := c.ws.NextWriter(websocket.TextMessage) + if err != nil { + c.handleSendMessageError("error creating writer for %+v: %s", msg, err) + return + } + + if _, err := w.Write(data); err != nil { + c.handleSendMessageError("error sending %+v: %s", msg, err) + return + } + + if err := w.Close(); err != nil { + c.handleSendMessageError("error during close of sending %+v: %s", msg, err) + } +} + +func (c *testProxyServerClient) run() { + defer func() { + c.mu.Lock() + defer c.mu.Unlock() + + c.server.removeClient(c) + c.ws = nil + }() + c.processMessage = c.processHello + for { + c.mu.Lock() + ws := c.ws + c.mu.Unlock() + if ws == nil { + break + } + + msgType, reader, err := ws.NextReader() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { + c.t.Error(err) + } + return + } + + body, err := io.ReadAll(reader) + if err != nil { + c.t.Error(err) + continue + } + + if msgType != websocket.TextMessage { + c.t.Errorf("unexpected message type %q (%s)", msgType, string(body)) + continue + } + + var msg ProxyClientMessage + if err := json.Unmarshal(body, &msg); err != nil { + c.t.Errorf("could not decode message %s: %s", string(body), err) + continue + } + + if err := msg.CheckValid(); err != nil { + c.t.Errorf("invalid message %s: %s", string(body), err) + continue + } + + response, err := c.processMessage(&msg) + if err != nil { + c.t.Error(err) + continue + } + + c.sendMessage(response) + if response.Type == "hello" { + c.server.sendLoad(c) + } + } +} + +type TestProxyServerHandler struct { + t *testing.T + + URL string + server *httptest.Server + servers []*TestProxyServerHandler + upgrader *websocket.Upgrader + country string + + mu sync.Mutex + load atomic.Int64 + incoming atomic.Pointer[float64] + outgoing atomic.Pointer[float64] + clients map[string]*testProxyServerClient + publishers map[string]*testProxyServerPublisher + subscribers map[string]*testProxyServerSubscriber +} + +func (h *TestProxyServerHandler) createPublisher() *testProxyServerPublisher { + h.mu.Lock() + defer h.mu.Unlock() + pub := &testProxyServerPublisher{ + id: newRandomString(32), + } + + for { + if _, found := h.publishers[pub.id]; !found { + break + } + + pub.id = newRandomString(32) + } + h.publishers[pub.id] = pub + return pub +} + +func (h *TestProxyServerHandler) getPublisher(id string) *testProxyServerPublisher { + h.mu.Lock() + defer h.mu.Unlock() + + return h.publishers[id] +} + +func (h *TestProxyServerHandler) deletePublisher(id string) (*testProxyServerPublisher, bool) { + h.mu.Lock() + defer h.mu.Unlock() + + pub, found := h.publishers[id] + if !found { + return nil, false + } + + delete(h.publishers, id) + return pub, true +} + +func (h *TestProxyServerHandler) createSubscriber(pub *testProxyServerPublisher) *testProxyServerSubscriber { + h.mu.Lock() + defer h.mu.Unlock() + + sub := &testProxyServerSubscriber{ + id: newRandomString(32), + sid: newRandomString(8), + pub: pub, + } + + for { + if _, found := h.subscribers[sub.id]; !found { + break + } + + sub.id = newRandomString(32) + } + h.subscribers[sub.id] = sub + return sub +} + +func (h *TestProxyServerHandler) deleteSubscriber(id string) (*testProxyServerSubscriber, bool) { + h.mu.Lock() + defer h.mu.Unlock() + + sub, found := h.subscribers[id] + if !found { + return nil, false + } + + delete(h.subscribers, id) + return sub, true +} + +func (h *TestProxyServerHandler) UpdateBandwidth(incoming float64, outgoing float64) { + h.incoming.Store(&incoming) + h.outgoing.Store(&outgoing) + + h.mu.Lock() + defer h.mu.Unlock() + + msg := h.getLoadMessage(h.load.Load()) + for _, c := range h.clients { + c.sendMessage(msg) + } +} + +func (h *TestProxyServerHandler) Clear(incoming bool, outgoing bool) { + if incoming { + h.incoming.Store(nil) + } + if outgoing { + h.outgoing.Store(nil) + } + + h.mu.Lock() + defer h.mu.Unlock() + + msg := h.getLoadMessage(h.load.Load()) + for _, c := range h.clients { + c.sendMessage(msg) + } +} + +func (h *TestProxyServerHandler) getLoadMessage(load int64) *ProxyServerMessage { + msg := &ProxyServerMessage{ + Type: "event", + Event: &EventProxyServerMessage{ + Type: "update-load", + Load: load, + }, + } + + incoming := h.incoming.Load() + outgoing := h.outgoing.Load() + if incoming != nil || outgoing != nil { + msg.Event.Bandwidth = &EventProxyServerBandwidth{ + Incoming: incoming, + Outgoing: outgoing, + } + } + return msg +} + +func (h *TestProxyServerHandler) updateLoad(delta int64) { + if delta == 0 { + return + } + + load := h.load.Add(delta) + + h.mu.Lock() + defer h.mu.Unlock() + + msg := h.getLoadMessage(load) + for _, c := range h.clients { + go c.sendMessage(msg) + } +} + +func (h *TestProxyServerHandler) sendLoad(c *testProxyServerClient) { + msg := h.getLoadMessage(h.load.Load()) + c.sendMessage(msg) +} + +func (h *TestProxyServerHandler) removeClient(client *testProxyServerClient) { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.clients, client.sessionId) +} + +func (h *TestProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ws, err := h.upgrader.Upgrade(w, r, nil) + if err != nil { + h.t.Error(err) + return + } + + client := &testProxyServerClient{ + t: h.t, + server: h, + ws: ws, + sessionId: newRandomString(32), + } + + h.mu.Lock() + h.clients[client.sessionId] = client + h.mu.Unlock() + + go client.run() +} + +func NewProxyServerForTest(t *testing.T, country string) *TestProxyServerHandler { + t.Helper() + + upgrader := websocket.Upgrader{} + proxyHandler := &TestProxyServerHandler{ + t: t, + upgrader: &upgrader, + country: country, + clients: make(map[string]*testProxyServerClient), + publishers: make(map[string]*testProxyServerPublisher), + subscribers: make(map[string]*testProxyServerSubscriber), + } + server := httptest.NewServer(proxyHandler) + proxyHandler.server = server + proxyHandler.URL = server.URL + t.Cleanup(func() { + server.Close() + proxyHandler.mu.Lock() + defer proxyHandler.mu.Unlock() + for _, c := range proxyHandler.clients { + c.close() + } + }) + + return proxyHandler +} + +type proxyTestOptions struct { + etcd *embed.Etcd + servers []*TestProxyServerHandler +} + +func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuProxy { + t.Helper() + if options.etcd == nil { + options.etcd = NewEtcdForTest(t) + } + grpcClients, dnsMonitor := NewGrpcClientsWithEtcdForTest(t, options.etcd) + + tokenKey, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + dir := t.TempDir() + privkeyFile := path.Join(dir, "privkey.pem") + pubkeyFile := path.Join(dir, "pubkey.pem") + WritePrivateKey(tokenKey, privkeyFile) // nolint + WritePublicKey(&tokenKey.PublicKey, pubkeyFile) // nolint + + cfg := goconf.NewConfigFile() + cfg.AddOption("mcu", "urltype", "static") + var urls []string + waitingMap := make(map[string]bool) + if len(options.servers) == 0 { + options.servers = []*TestProxyServerHandler{ + NewProxyServerForTest(t, "DE"), + } + } + for _, s := range options.servers { + s.servers = options.servers + urls = append(urls, s.URL) + waitingMap[s.URL] = true + } + cfg.AddOption("mcu", "url", strings.Join(urls, " ")) + cfg.AddOption("mcu", "token_id", "test-token") + cfg.AddOption("mcu", "token_key", privkeyFile) + + etcdConfig := goconf.NewConfigFile() + etcdConfig.AddOption("etcd", "endpoints", options.etcd.Config().ListenClientUrls[0].String()) + etcdConfig.AddOption("etcd", "loglevel", "error") + + etcdClient, err := NewEtcdClient(etcdConfig, "") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := etcdClient.Close(); err != nil { + t.Error(err) + } + }) + + mcu, err := NewMcuProxy(cfg, etcdClient, grpcClients, dnsMonitor) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + mcu.Stop() + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + if err := mcu.Start(ctx); err != nil { + t.Fatal(err) + } + + proxy := mcu.(*mcuProxy) + + if err := proxy.WaitForConnections(ctx); err != nil { + t.Fatal(err) + } + + for len(waitingMap) > 0 { + if err := ctx.Err(); err != nil { + t.Fatal(err) + } + + for u := range waitingMap { + proxy.connectionsMu.RLock() + connections := proxy.connections + proxy.connectionsMu.RUnlock() + for _, c := range connections { + if c.rawUrl == u && c.IsConnected() && c.SessionId() != "" { + delete(waitingMap, u) + break + } + } + } + + time.Sleep(time.Millisecond) + } + + return proxy +} + +func newMcuProxyForTestWithServers(t *testing.T, servers []*TestProxyServerHandler) *mcuProxy { + t.Helper() + + return newMcuProxyForTestWithOptions(t, proxyTestOptions{ + servers: servers, + }) +} + +func newMcuProxyForTest(t *testing.T) *mcuProxy { + t.Helper() + server := NewProxyServerForTest(t, "DE") + + return newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{server}) +} + +func Test_ProxyPublisherSubscriber(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + mcu := newMcuProxyForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + + pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + defer pub.Close(context.Background()) + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "DE", + } + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + if err != nil { + t.Fatal(err) + } + + defer sub.Close(context.Background()) +} + +func Test_ProxyWaitForPublisher(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + mcu := newMcuProxyForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "DE", + } + done := make(chan struct{}) + go func() { + defer close(done) + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + if err != nil { + t.Error(err) + return + } + + defer sub.Close(context.Background()) + }() + + // Give subscriber goroutine some time to start + time.Sleep(100 * time.Millisecond) + + pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + select { + case <-done: + case <-ctx.Done(): + t.Error(ctx.Err()) + } + defer pub.Close(context.Background()) +} + +func Test_ProxyPublisherBandwidth(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + server1 := NewProxyServerForTest(t, "DE") + server2 := NewProxyServerForTest(t, "DE") + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ + server1, + server2, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pub1Id := "the-publisher-1" + pub1Sid := "1234567890" + pub1Listener := &MockMcuListener{ + publicId: pub1Id + "-public", + } + pub1Initiator := &MockMcuInitiator{ + country: "DE", + } + pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator) + if err != nil { + t.Fatal(err) + } + + defer pub1.Close(context.Background()) + + if pub1.(*mcuProxyPublisher).conn.rawUrl == server1.URL { + server1.UpdateBandwidth(100, 0) + } else { + server2.UpdateBandwidth(100, 0) + } + + // Wait until proxy has been updated + for ctx.Err() == nil { + mcu.connectionsMu.RLock() + connections := mcu.connections + mcu.connectionsMu.RUnlock() + missing := true + for _, c := range connections { + if c.Bandwidth() != nil { + missing = false + break + } + } + if !missing { + break + } + time.Sleep(time.Millisecond) + } + + pub2Id := "the-publisher-2" + pub2id := "1234567890" + pub2Listener := &MockMcuListener{ + publicId: pub2Id + "-public", + } + pub2Initiator := &MockMcuInitiator{ + country: "DE", + } + pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator) + if err != nil { + t.Fatal(err) + } + + defer pub2.Close(context.Background()) + + if pub1.(*mcuProxyPublisher).conn.rawUrl == pub2.(*mcuProxyPublisher).conn.rawUrl { + t.Errorf("servers should be different, got %s", pub1.(*mcuProxyPublisher).conn.rawUrl) + } +} + +func Test_ProxyPublisherBandwidthOverload(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + server1 := NewProxyServerForTest(t, "DE") + server2 := NewProxyServerForTest(t, "DE") + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ + server1, + server2, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pub1Id := "the-publisher-1" + pub1Sid := "1234567890" + pub1Listener := &MockMcuListener{ + publicId: pub1Id + "-public", + } + pub1Initiator := &MockMcuInitiator{ + country: "DE", + } + pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator) + if err != nil { + t.Fatal(err) + } + + defer pub1.Close(context.Background()) + + // If all servers are bandwidth loaded, select the one with the least usage. + if pub1.(*mcuProxyPublisher).conn.rawUrl == server1.URL { + server1.UpdateBandwidth(100, 0) + server2.UpdateBandwidth(102, 0) + } else { + server1.UpdateBandwidth(102, 0) + server2.UpdateBandwidth(100, 0) + } + + // Wait until proxy has been updated + for ctx.Err() == nil { + mcu.connectionsMu.RLock() + connections := mcu.connections + mcu.connectionsMu.RUnlock() + missing := false + for _, c := range connections { + if c.Bandwidth() == nil { + missing = true + break + } + } + if !missing { + break + } + time.Sleep(time.Millisecond) + } + + pub2Id := "the-publisher-2" + pub2id := "1234567890" + pub2Listener := &MockMcuListener{ + publicId: pub2Id + "-public", + } + pub2Initiator := &MockMcuInitiator{ + country: "DE", + } + pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator) + if err != nil { + t.Fatal(err) + } + + defer pub2.Close(context.Background()) + + if pub1.(*mcuProxyPublisher).conn.rawUrl != pub2.(*mcuProxyPublisher).conn.rawUrl { + t.Errorf("servers should be the same, got %s / %s", pub1.(*mcuProxyPublisher).conn.rawUrl, pub2.(*mcuProxyPublisher).conn.rawUrl) + } +} + +func Test_ProxyPublisherLoad(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + server1 := NewProxyServerForTest(t, "DE") + server2 := NewProxyServerForTest(t, "DE") + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ + server1, + server2, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pub1Id := "the-publisher-1" + pub1Sid := "1234567890" + pub1Listener := &MockMcuListener{ + publicId: pub1Id + "-public", + } + pub1Initiator := &MockMcuInitiator{ + country: "DE", + } + pub1, err := mcu.NewPublisher(ctx, pub1Listener, pub1Id, pub1Sid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub1Initiator) + if err != nil { + t.Fatal(err) + } + + defer pub1.Close(context.Background()) + + // Make sure connections are re-sorted. + mcu.nextSort.Store(0) + time.Sleep(100 * time.Millisecond) + + pub2Id := "the-publisher-2" + pub2id := "1234567890" + pub2Listener := &MockMcuListener{ + publicId: pub2Id + "-public", + } + pub2Initiator := &MockMcuInitiator{ + country: "DE", + } + pub2, err := mcu.NewPublisher(ctx, pub2Listener, pub2Id, pub2id, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pub2Initiator) + if err != nil { + t.Fatal(err) + } + + defer pub2.Close(context.Background()) + + if pub1.(*mcuProxyPublisher).conn.rawUrl == pub2.(*mcuProxyPublisher).conn.rawUrl { + t.Errorf("servers should be different, got %s", pub1.(*mcuProxyPublisher).conn.rawUrl) + } +} + +func Test_ProxyPublisherCountry(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + serverDE := NewProxyServerForTest(t, "DE") + serverUS := NewProxyServerForTest(t, "US") + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ + serverDE, + serverUS, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubDEId := "the-publisher-de" + pubDESid := "1234567890" + pubDEListener := &MockMcuListener{ + publicId: pubDEId + "-public", + } + pubDEInitiator := &MockMcuInitiator{ + country: "DE", + } + pubDE, err := mcu.NewPublisher(ctx, pubDEListener, pubDEId, pubDESid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubDEInitiator) + if err != nil { + t.Fatal(err) + } + + defer pubDE.Close(context.Background()) + + if pubDE.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl) + } + + pubUSId := "the-publisher-us" + pubUSSid := "1234567890" + pubUSListener := &MockMcuListener{ + publicId: pubUSId + "-public", + } + pubUSInitiator := &MockMcuInitiator{ + country: "US", + } + pubUS, err := mcu.NewPublisher(ctx, pubUSListener, pubUSId, pubUSSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubUSInitiator) + if err != nil { + t.Fatal(err) + } + + defer pubUS.Close(context.Background()) + + if pubUS.(*mcuProxyPublisher).conn.rawUrl != serverUS.URL { + t.Errorf("expected server %s, go %s", serverUS.URL, pubUS.(*mcuProxyPublisher).conn.rawUrl) + } +} + +func Test_ProxyPublisherContinent(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + serverDE := NewProxyServerForTest(t, "DE") + serverUS := NewProxyServerForTest(t, "US") + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ + serverDE, + serverUS, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubDEId := "the-publisher-de" + pubDESid := "1234567890" + pubDEListener := &MockMcuListener{ + publicId: pubDEId + "-public", + } + pubDEInitiator := &MockMcuInitiator{ + country: "DE", + } + pubDE, err := mcu.NewPublisher(ctx, pubDEListener, pubDEId, pubDESid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubDEInitiator) + if err != nil { + t.Fatal(err) + } + + defer pubDE.Close(context.Background()) + + if pubDE.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pubDE.(*mcuProxyPublisher).conn.rawUrl) + } + + pubFRId := "the-publisher-fr" + pubFRSid := "1234567890" + pubFRListener := &MockMcuListener{ + publicId: pubFRId + "-public", + } + pubFRInitiator := &MockMcuInitiator{ + country: "FR", + } + pubFR, err := mcu.NewPublisher(ctx, pubFRListener, pubFRId, pubFRSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubFRInitiator) + if err != nil { + t.Fatal(err) + } + + defer pubFR.Close(context.Background()) + + if pubFR.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pubFR.(*mcuProxyPublisher).conn.rawUrl) + } +} + +func Test_ProxySubscriberCountry(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + serverDE := NewProxyServerForTest(t, "DE") + serverUS := NewProxyServerForTest(t, "US") + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ + serverDE, + serverUS, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + defer pub.Close(context.Background()) + + if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) + } + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "US", + } + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + if err != nil { + t.Fatal(err) + } + + defer sub.Close(context.Background()) + + if sub.(*mcuProxySubscriber).conn.rawUrl != serverUS.URL { + t.Errorf("expected server %s, go %s", serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl) + } +} + +func Test_ProxySubscriberContinent(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + serverDE := NewProxyServerForTest(t, "DE") + serverUS := NewProxyServerForTest(t, "US") + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ + serverDE, + serverUS, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + defer pub.Close(context.Background()) + + if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) + } + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "FR", + } + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + if err != nil { + t.Fatal(err) + } + + defer sub.Close(context.Background()) + + if sub.(*mcuProxySubscriber).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl) + } +} + +func Test_ProxySubscriberBandwidth(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + serverDE := NewProxyServerForTest(t, "DE") + serverUS := NewProxyServerForTest(t, "US") + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ + serverDE, + serverUS, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + defer pub.Close(context.Background()) + + if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) + } + + serverDE.UpdateBandwidth(0, 100) + + // Wait until proxy has been updated + for ctx.Err() == nil { + mcu.connectionsMu.RLock() + connections := mcu.connections + mcu.connectionsMu.RUnlock() + missing := true + for _, c := range connections { + if c.Bandwidth() != nil { + missing = false + break + } + } + if !missing { + break + } + time.Sleep(time.Millisecond) + } + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "US", + } + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + if err != nil { + t.Fatal(err) + } + + defer sub.Close(context.Background()) + + if sub.(*mcuProxySubscriber).conn.rawUrl != serverUS.URL { + t.Errorf("expected server %s, go %s", serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl) + } +} + +func Test_ProxySubscriberBandwidthOverload(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + serverDE := NewProxyServerForTest(t, "DE") + serverUS := NewProxyServerForTest(t, "US") + mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{ + serverDE, + serverUS, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + defer pub.Close(context.Background()) + + if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl) + } + + serverDE.UpdateBandwidth(0, 100) + serverUS.UpdateBandwidth(0, 102) + + // Wait until proxy has been updated + for ctx.Err() == nil { + mcu.connectionsMu.RLock() + connections := mcu.connections + mcu.connectionsMu.RUnlock() + missing := false + for _, c := range connections { + if c.Bandwidth() == nil { + missing = true + break + } + } + if !missing { + break + } + time.Sleep(time.Millisecond) + } + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "US", + } + sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + if err != nil { + t.Fatal(err) + } + + defer sub.Close(context.Background()) + + if sub.(*mcuProxySubscriber).conn.rawUrl != serverDE.URL { + t.Errorf("expected server %s, go %s", serverDE.URL, sub.(*mcuProxySubscriber).conn.rawUrl) + } +} + +type mockGrpcServerHub struct { + sessionsLock sync.Mutex + sessionByPublicId map[string]Session +} + +func (h *mockGrpcServerHub) addSession(session *ClientSession) { + h.sessionsLock.Lock() + defer h.sessionsLock.Unlock() + if h.sessionByPublicId == nil { + h.sessionByPublicId = make(map[string]Session) + } + h.sessionByPublicId[session.PublicId()] = session +} + +func (h *mockGrpcServerHub) removeSession(session *ClientSession) { + h.sessionsLock.Lock() + defer h.sessionsLock.Unlock() + delete(h.sessionByPublicId, session.PublicId()) +} + +func (h *mockGrpcServerHub) GetSessionByResumeId(resumeId string) Session { + return nil +} + +func (h *mockGrpcServerHub) GetSessionByPublicId(sessionId string) Session { + h.sessionsLock.Lock() + defer h.sessionsLock.Unlock() + return h.sessionByPublicId[sessionId] +} + +func (h *mockGrpcServerHub) GetSessionIdByRoomSessionId(roomSessionId string) (string, error) { + return "", nil +} + +func (h *mockGrpcServerHub) GetBackend(u *url.URL) *Backend { + return nil +} + +func Test_ProxyRemotePublisher(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + + etcd := NewEtcdForTest(t) + + grpcServer1, addr1 := NewGrpcServerForTest(t) + grpcServer2, addr2 := NewGrpcServerForTest(t) + + hub1 := &mockGrpcServerHub{} + hub2 := &mockGrpcServerHub{} + grpcServer1.hub = hub1 + grpcServer2.hub = hub2 + + SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + + server1 := NewProxyServerForTest(t, "DE") + server2 := NewProxyServerForTest(t, "DE") + + mcu1 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server1, + server2, + }, + }) + mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server1, + server2, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + + session1 := &ClientSession{ + publicId: pubId, + publishers: make(map[StreamType]McuPublisher), + } + hub1.addSession(session1) + defer hub1.removeSession(session1) + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + defer pub.Close(context.Background()) + + session1.mu.Lock() + session1.publishers[StreamTypeVideo] = pub + session1.publisherWaiters.Wakeup() + session1.mu.Unlock() + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "DE", + } + sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + if err != nil { + t.Fatal(err) + } + + defer sub.Close(context.Background()) +} + +func Test_ProxyRemotePublisherWait(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + + etcd := NewEtcdForTest(t) + + grpcServer1, addr1 := NewGrpcServerForTest(t) + grpcServer2, addr2 := NewGrpcServerForTest(t) + + hub1 := &mockGrpcServerHub{} + hub2 := &mockGrpcServerHub{} + grpcServer1.hub = hub1 + grpcServer2.hub = hub2 + + SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + + server1 := NewProxyServerForTest(t, "DE") + server2 := NewProxyServerForTest(t, "DE") + + mcu1 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server1, + server2, + }, + }) + mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server1, + server2, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + + session1 := &ClientSession{ + publicId: pubId, + publishers: make(map[StreamType]McuPublisher), + } + hub1.addSession(session1) + defer hub1.removeSession(session1) + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "DE", + } + + done := make(chan struct{}) + go func() { + defer close(done) + sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + if err != nil { + t.Error(err) + return + } + + defer sub.Close(context.Background()) + }() + + // Give subscriber goroutine some time to start + time.Sleep(100 * time.Millisecond) + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + defer pub.Close(context.Background()) + + session1.mu.Lock() + session1.publishers[StreamTypeVideo] = pub + session1.publisherWaiters.Wakeup() + session1.mu.Unlock() + + select { + case <-done: + case <-ctx.Done(): + t.Error(ctx.Err()) + } +} + +func Test_ProxyRemotePublisherTemporary(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + + etcd := NewEtcdForTest(t) + + grpcServer1, addr1 := NewGrpcServerForTest(t) + grpcServer2, addr2 := NewGrpcServerForTest(t) + + hub1 := &mockGrpcServerHub{} + hub2 := &mockGrpcServerHub{} + grpcServer1.hub = hub1 + grpcServer2.hub = hub2 + + SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + + server1 := NewProxyServerForTest(t, "DE") + server2 := NewProxyServerForTest(t, "DE") + + mcu1 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server1, + }, + }) + mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server2, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + + session1 := &ClientSession{ + publicId: pubId, + publishers: make(map[StreamType]McuPublisher), + } + hub1.addSession(session1) + defer hub1.removeSession(session1) + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator) + if err != nil { + t.Fatal(err) + } + + defer pub.Close(context.Background()) + + session1.mu.Lock() + session1.publishers[StreamTypeVideo] = pub + session1.publisherWaiters.Wakeup() + session1.mu.Unlock() + + mcu2.connectionsMu.RLock() + count := len(mcu2.connections) + mcu2.connectionsMu.RUnlock() + if expected := 1; count != expected { + t.Errorf("expected %d connections, got %+v", expected, count) + } + + subListener := &MockMcuListener{ + publicId: "subscriber-public", + } + subInitiator := &MockMcuInitiator{ + country: "DE", + } + sub, err := mcu2.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator) + if err != nil { + t.Fatal(err) + } + + defer sub.Close(context.Background()) + + if sub.(*mcuProxySubscriber).conn.rawUrl != server1.URL { + t.Errorf("expected server %s, go %s", server1.URL, sub.(*mcuProxySubscriber).conn.rawUrl) + } + + // The temporary connection has been added + mcu2.connectionsMu.RLock() + count = len(mcu2.connections) + mcu2.connectionsMu.RUnlock() + if expected := 2; count != expected { + t.Errorf("expected %d connections, got %+v", expected, count) + } + + sub.Close(context.Background()) + + // Wait for temporary connection to be removed. +loop: + for { + select { + case <-ctx.Done(): + t.Error(ctx.Err()) + default: + mcu2.connectionsMu.RLock() + count = len(mcu2.connections) + mcu2.connectionsMu.RUnlock() + if count == 1 { + break loop + } + } + } +} diff --git a/mcu_test.go b/mcu_test.go index 903a2bc..ae1de23 100644 --- a/mcu_test.go +++ b/mcu_test.go @@ -23,6 +23,7 @@ package signaling import ( "context" + "errors" "fmt" "log" "sync" @@ -49,7 +50,7 @@ func NewTestMCU() (*TestMCU, error) { }, nil } -func (m *TestMCU) Start() error { +func (m *TestMCU) Start(ctx context.Context) error { return nil } @@ -117,7 +118,7 @@ func (m *TestMCU) GetPublisher(id string) *TestMCUPublisher { return m.publishers[id] } -func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) { +func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) { m.mu.Lock() defer m.mu.Unlock() @@ -222,6 +223,18 @@ func (p *TestMCUPublisher) SendMessage(ctx context.Context, message *MessageClie }() } +func (p *TestMCUPublisher) GetStreams(ctx context.Context) ([]PublisherStream, error) { + return nil, errors.New("not implemented") +} + +func (p *TestMCUPublisher) PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { + return errors.New("remote publishing not supported") +} + +func (p *TestMCUPublisher) UnpublishRemote(ctx context.Context, remoteId string) error { + return errors.New("remote publishing not supported") +} + type TestMCUSubscriber struct { TestMCUClient diff --git a/proxy.conf.in b/proxy.conf.in index 0fcb350..0f33cf5 100644 --- a/proxy.conf.in +++ b/proxy.conf.in @@ -25,6 +25,36 @@ # - etcd: Token information are retrieved from an etcd cluster (see below). tokentype = static +# The external hostname for remote streams. Leaving this empty will autodetect +# and use the first public IP found on the available network interfaces. +#hostname = + +# The token id to use when connecting remote stream. +#token_id = server1 + +# The private key for the configured token id to use when connecting remote +# streams. +#token_key = privkey.pem + +# If set to "true", certificate validation of remote stream requests will be +# skipped. This should only be enabled during development, e.g. to work with +# self-signed certificates. +#skipverify = false + +[bandwidth] +# Target bandwidth limit for incoming streams (in megabits per second). +# Set to 0 to disable the limit. If the limit is reached, the proxy notifies +# the signaling servers that another proxy should be used for publishing if +# possible. +#incoming = 1024 + +# Target bandwidth limit for outgoing streams (in megabits per second). +# Set to 0 to disable the limit. If the limit is reached, the proxy notifies +# the signaling servers that another proxy should be used for subscribing if +# possible. Note that this might require additional outgoing bandwidth for the +# remote streams. +#outgoing = 1024 + [tokens] # For token type "static": Mapping of = of signaling # servers allowed to connect. diff --git a/proxy/proxy_remote.go b/proxy/proxy_remote.go new file mode 100644 index 0000000..838cecc --- /dev/null +++ b/proxy/proxy_remote.go @@ -0,0 +1,490 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "context" + "crypto/rsa" + "crypto/tls" + "encoding/json" + "errors" + "log" + "net/http" + "net/url" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/gorilla/websocket" + + signaling "github.com/strukturag/nextcloud-spreed-signaling" +) + +const ( + initialReconnectInterval = 1 * time.Second + maxReconnectInterval = 32 * time.Second + + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 +) + +var ( + ErrNotConnected = errors.New("not connected") +) + +type RemoteConnection struct { + mu sync.Mutex + url *url.URL + conn *websocket.Conn + closer *signaling.Closer + closed atomic.Bool + + tokenId string + tokenKey *rsa.PrivateKey + tlsConfig *tls.Config + + connectedSince time.Time + reconnectTimer *time.Timer + reconnectInterval atomic.Int64 + + msgId atomic.Int64 + helloMsgId string + sessionId string + + pendingMessages []*signaling.ProxyClientMessage + messageCallbacks map[string]chan *signaling.ProxyServerMessage +} + +func NewRemoteConnection(proxyUrl string, tokenId string, tokenKey *rsa.PrivateKey, tlsConfig *tls.Config) (*RemoteConnection, error) { + u, err := url.Parse(proxyUrl) + if err != nil { + return nil, err + } + + result := &RemoteConnection{ + url: u, + closer: signaling.NewCloser(), + + tokenId: tokenId, + tokenKey: tokenKey, + tlsConfig: tlsConfig, + + reconnectTimer: time.NewTimer(0), + + messageCallbacks: make(map[string]chan *signaling.ProxyServerMessage), + } + result.reconnectInterval.Store(int64(initialReconnectInterval)) + + go result.writePump() + + return result, nil +} + +func (c *RemoteConnection) String() string { + return c.url.String() +} + +func (c *RemoteConnection) reconnect() { + u, err := c.url.Parse("proxy") + if err != nil { + log.Printf("Could not resolve url to proxy at %s: %s", c, err) + c.scheduleReconnect() + return + } + if u.Scheme == "http" { + u.Scheme = "ws" + } else if u.Scheme == "https" { + u.Scheme = "wss" + } + + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: c.tlsConfig, + } + + conn, _, err := dialer.DialContext(context.TODO(), u.String(), nil) + if err != nil { + log.Printf("Error connecting to proxy at %s: %s", c, err) + c.scheduleReconnect() + return + } + + log.Printf("Connected to %s", c) + c.closed.Store(false) + + c.mu.Lock() + c.connectedSince = time.Now() + c.conn = conn + c.mu.Unlock() + + c.reconnectInterval.Store(int64(initialReconnectInterval)) + + if err := c.sendHello(); err != nil { + log.Printf("Error sending hello request to proxy at %s: %s", c, err) + c.scheduleReconnect() + return + } + + if !c.sendPing() { + return + } + + go c.readPump(conn) +} + +func (c *RemoteConnection) scheduleReconnect() { + if err := c.sendClose(); err != nil && err != ErrNotConnected { + log.Printf("Could not send close message to %s: %s", c, err) + } + c.close() + + interval := c.reconnectInterval.Load() + c.reconnectTimer.Reset(time.Duration(interval)) + + interval = interval * 2 + if interval > int64(maxReconnectInterval) { + interval = int64(maxReconnectInterval) + } + c.reconnectInterval.Store(interval) +} + +func (c *RemoteConnection) sendHello() error { + c.helloMsgId = strconv.FormatInt(c.msgId.Add(1), 10) + msg := &signaling.ProxyClientMessage{ + Id: c.helloMsgId, + Type: "hello", + Hello: &signaling.HelloProxyClientMessage{ + Version: "1.0", + }, + } + if sessionId := c.sessionId; sessionId != "" { + msg.Hello.ResumeId = sessionId + } else { + tokenString, err := c.createToken("") + if err != nil { + return err + } + + msg.Hello.Token = tokenString + } + + return c.SendMessage(msg) +} + +func (c *RemoteConnection) sendClose() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn == nil { + return ErrNotConnected + } + + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint + return c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) +} + +func (c *RemoteConnection) close() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn != nil { + c.conn.Close() + c.conn = nil + } +} + +func (c *RemoteConnection) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.reconnectTimer.Stop() + if c.conn == nil { + return nil + } + + c.sendClose() + err1 := c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}) + err2 := c.conn.Close() + c.conn = nil + if err1 != nil { + return err1 + } + return err2 +} + +func (c *RemoteConnection) createToken(subject string) (string, error) { + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: c.tokenId, + Subject: subject, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(c.tokenKey) + if err != nil { + return "", err + } + + return tokenString, nil +} + +func (c *RemoteConnection) SendMessage(msg *signaling.ProxyClientMessage) error { + c.mu.Lock() + defer c.mu.Unlock() + + return c.sendMessageLocked(context.Background(), msg) +} + +func (c *RemoteConnection) deferMessage(ctx context.Context, msg *signaling.ProxyClientMessage) { + c.pendingMessages = append(c.pendingMessages, msg) + if ctx.Done() != nil { + go func() { + <-ctx.Done() + + c.mu.Lock() + defer c.mu.Unlock() + for idx, m := range c.pendingMessages { + if m == msg { + c.pendingMessages[idx] = nil + break + } + } + }() + } +} + +func (c *RemoteConnection) sendMessageLocked(ctx context.Context, msg *signaling.ProxyClientMessage) error { + if c.conn == nil { + // Defer until connected. + c.deferMessage(ctx, msg) + return nil + } + + if c.helloMsgId != "" && c.helloMsgId != msg.Id { + // Hello request is still inflight, defer. + c.deferMessage(ctx, msg) + return nil + } + + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint + return c.conn.WriteJSON(msg) +} + +func (c *RemoteConnection) readPump(conn *websocket.Conn) { + defer func() { + if !c.closed.Load() { + c.scheduleReconnect() + } + }() + defer c.close() + + for { + msgType, msg, err := conn.ReadMessage() + if err != nil { + if errors.Is(err, websocket.ErrCloseSent) { + break + } else if _, ok := err.(*websocket.CloseError); !ok || websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived) { + log.Printf("Error reading from %s: %v", c, err) + } + break + } + + if msgType != websocket.TextMessage { + log.Printf("unexpected message type %q (%s)", msgType, string(msg)) + continue + } + + var message signaling.ProxyServerMessage + if err := json.Unmarshal(msg, &message); err != nil { + log.Printf("could not decode message %s: %s", string(msg), err) + continue + } + + c.mu.Lock() + helloMsgId := c.helloMsgId + c.mu.Unlock() + + if helloMsgId != "" && message.Id == helloMsgId { + c.processHello(&message) + } else { + c.processMessage(&message) + } + } +} + +func (c *RemoteConnection) sendPing() bool { + c.mu.Lock() + defer c.mu.Unlock() + if c.conn == nil { + return false + } + + now := time.Now() + msg := strconv.FormatInt(now.UnixNano(), 10) + c.conn.SetWriteDeadline(now.Add(writeWait)) // nolint + if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil { + log.Printf("Could not send ping to proxy at %s: %v", c, err) + go c.scheduleReconnect() + return false + } + + return true +} + +func (c *RemoteConnection) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + }() + + defer c.reconnectTimer.Stop() + for { + select { + case <-c.reconnectTimer.C: + c.reconnect() + case <-ticker.C: + c.sendPing() + case <-c.closer.C: + return + } + } +} + +func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { + c.helloMsgId = "" + switch msg.Type { + case "error": + if msg.Error.Code == "no_such_session" { + log.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c) + c.sessionId = "" + if err := c.sendHello(); err != nil { + log.Printf("Could not send hello request to %s: %s", c, err) + c.scheduleReconnect() + } + return + } + + log.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error) + c.scheduleReconnect() + case "hello": + resumed := c.sessionId == msg.Hello.SessionId + c.sessionId = msg.Hello.SessionId + country := "" + if msg.Hello.Server != nil { + if country = msg.Hello.Server.Country; country != "" && !signaling.IsValidCountry(country) { + log.Printf("Proxy %s sent invalid country %s in hello response", c, country) + country = "" + } + } + if resumed { + log.Printf("Resumed session %s on %s", c.sessionId, c) + } else if country != "" { + log.Printf("Received session %s from %s (in %s)", c.sessionId, c, country) + } else { + log.Printf("Received session %s from %s", c.sessionId, c) + } + + pending := c.pendingMessages + c.pendingMessages = nil + for _, m := range pending { + if m == nil { + continue + } + + if err := c.sendMessageLocked(context.Background(), m); err != nil { + log.Printf("Could not send pending message %+v to %s: %s", m, c, err) + } + } + default: + log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c) + c.scheduleReconnect() + } +} + +func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) { + if msg.Id != "" { + c.mu.Lock() + ch, found := c.messageCallbacks[msg.Id] + if found { + delete(c.messageCallbacks, msg.Id) + c.mu.Unlock() + ch <- msg + return + } + c.mu.Unlock() + } + + switch msg.Type { + case "event": + c.processEvent(msg) + default: + log.Printf("Received unsupported message %+v from %s", msg, c) + } +} + +func (c *RemoteConnection) processEvent(msg *signaling.ProxyServerMessage) { + switch msg.Event.Type { + case "update-load": + default: + log.Printf("Received unsupported event %+v from %s", msg, c) + } +} + +func (c *RemoteConnection) RequestMessage(ctx context.Context, msg *signaling.ProxyClientMessage) (*signaling.ProxyServerMessage, error) { + msg.Id = strconv.FormatInt(c.msgId.Add(1), 10) + + c.mu.Lock() + defer c.mu.Unlock() + + if err := c.sendMessageLocked(ctx, msg); err != nil { + return nil, err + } + ch := make(chan *signaling.ProxyServerMessage, 1) + c.messageCallbacks[msg.Id] = ch + c.mu.Unlock() + defer func() { + c.mu.Lock() + delete(c.messageCallbacks, msg.Id) + }() + + select { + case <-ctx.Done(): + // TODO: Cancel request. + return nil, ctx.Err() + case response := <-ch: + if response.Type == "error" { + return nil, response.Error + } + return response, nil + } +} diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index feaf306..1d0d4fe 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -24,7 +24,10 @@ package main import ( "context" "crypto/rand" + "crypto/rsa" + "crypto/tls" "encoding/json" + "errors" "fmt" "io" "log" @@ -45,6 +48,7 @@ import ( "github.com/gorilla/mux" "github.com/gorilla/securecookie" "github.com/gorilla/websocket" + "github.com/notedit/janus-go" "github.com/prometheus/client_golang/prometheus/promhttp" signaling "github.com/strukturag/nextcloud-spreed-signaling" @@ -63,6 +67,8 @@ const ( // Maximum age a token may have to prevent reuse of old tokens. maxTokenAge = 5 * time.Minute + + remotePublisherTimeout = 5 * time.Second ) type ContextKey string @@ -70,28 +76,35 @@ type ContextKey string var ( ContextKeySession = ContextKey("session") - TimeoutCreatingPublisher = signaling.NewError("timeout", "Timeout creating publisher.") - TimeoutCreatingSubscriber = signaling.NewError("timeout", "Timeout creating subscriber.") - TokenAuthFailed = signaling.NewError("auth_failed", "The token could not be authenticated.") - TokenExpired = signaling.NewError("token_expired", "The token is expired.") - TokenNotValidYet = signaling.NewError("token_not_valid_yet", "The token is not valid yet.") - UnknownClient = signaling.NewError("unknown_client", "Unknown client id given.") - UnsupportedCommand = signaling.NewError("bad_request", "Unsupported command received.") - UnsupportedMessage = signaling.NewError("bad_request", "Unsupported message received.") - UnsupportedPayload = signaling.NewError("unsupported_payload", "Unsupported payload type.") - ShutdownScheduled = signaling.NewError("shutdown_scheduled", "The server is scheduled to shutdown.") + TimeoutCreatingPublisher = signaling.NewError("timeout", "Timeout creating publisher.") + TimeoutCreatingSubscriber = signaling.NewError("timeout", "Timeout creating subscriber.") + TokenAuthFailed = signaling.NewError("auth_failed", "The token could not be authenticated.") + TokenExpired = signaling.NewError("token_expired", "The token is expired.") + TokenNotValidYet = signaling.NewError("token_not_valid_yet", "The token is not valid yet.") + UnknownClient = signaling.NewError("unknown_client", "Unknown client id given.") + UnsupportedCommand = signaling.NewError("bad_request", "Unsupported command received.") + UnsupportedMessage = signaling.NewError("bad_request", "Unsupported message received.") + UnsupportedPayload = signaling.NewError("unsupported_payload", "Unsupported payload type.") + ShutdownScheduled = signaling.NewError("shutdown_scheduled", "The server is scheduled to shutdown.") + RemoteSubscribersNotSupported = signaling.NewError("unsupported_subscriber", "Remote subscribers are not supported.") ) type ProxyServer struct { version string country string welcomeMessage string + config *goconf.ConfigFile url string mcu signaling.Mcu stopped atomic.Bool load atomic.Int64 + maxIncoming int64 + currentIncoming atomic.Int64 + maxOutgoing int64 + currentOutgoing atomic.Int64 + shutdownChannel chan struct{} shutdownScheduled atomic.Bool @@ -109,6 +122,48 @@ type ProxyServer struct { clients map[string]signaling.McuClient clientIds map[string]string clientsLock sync.RWMutex + + tokenId string + tokenKey *rsa.PrivateKey + remoteTlsConfig *tls.Config + remoteHostname string + remoteConnections map[string]*RemoteConnection + remoteConnectionsLock sync.Mutex +} + +func IsPublicIP(IP net.IP) bool { + if IP.IsLoopback() || IP.IsLinkLocalMulticast() || IP.IsLinkLocalUnicast() { + return false + } + if ip4 := IP.To4(); ip4 != nil { + switch { + case ip4[0] == 10: + return false + case ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31: + return false + case ip4[0] == 192 && ip4[1] == 168: + return false + default: + return true + } + } + return false +} + +func GetLocalIP() (string, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", err + } + + for _, address := range addrs { + if ipnet, ok := address.(*net.IPNet); ok && IsPublicIP(ipnet.IP) { + if ipnet.IP.To4() != nil { + return ipnet.IP.String(), nil + } + } + } + return "", nil } func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*ProxyServer, error) { @@ -187,10 +242,75 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* return nil, err } + tokenId, _ := config.GetString("app", "token_id") + var tokenKey *rsa.PrivateKey + var remoteHostname string + var remoteTlsConfig *tls.Config + if tokenId != "" { + tokenKeyFilename, _ := config.GetString("app", "token_key") + if tokenKeyFilename == "" { + return nil, fmt.Errorf("No token key configured") + } + tokenKeyData, err := os.ReadFile(tokenKeyFilename) + if err != nil { + return nil, fmt.Errorf("Could not read private key from %s: %s", tokenKeyFilename, err) + } + tokenKey, err = jwt.ParseRSAPrivateKeyFromPEM(tokenKeyData) + if err != nil { + return nil, fmt.Errorf("Could not parse private key from %s: %s", tokenKeyFilename, err) + } + log.Printf("Using \"%s\" as token id for remote streams", tokenId) + + remoteHostname, _ = config.GetString("app", "hostname") + if remoteHostname == "" { + remoteHostname, err = GetLocalIP() + if err != nil { + return nil, fmt.Errorf("could not get local ip: %w", err) + } + } + if remoteHostname == "" { + log.Printf("WARNING: Could not determine hostname for remote streams, will be disabled. Please configure manually.") + } else { + log.Printf("Using \"%s\" as hostname for remote streams", remoteHostname) + } + + skipverify, _ := config.GetBool("backend", "skipverify") + if skipverify { + log.Println("WARNING: Remote stream requests verification is disabled!") + remoteTlsConfig = &tls.Config{ + InsecureSkipVerify: skipverify, + } + } + } else { + log.Printf("No token id configured, remote streams will be disabled") + } + + maxIncoming, _ := config.GetInt("bandwidth", "incoming") + if maxIncoming < 0 { + maxIncoming = 0 + } + if maxIncoming > 0 { + log.Printf("Target bandwidth for incoming streams: %d MBit/s", maxIncoming) + } else { + log.Printf("Target bandwidth for incoming streams: unlimited") + } + maxOutgoing, _ := config.GetInt("bandwidth", "outgoing") + if maxOutgoing < 0 { + maxOutgoing = 0 + } + if maxIncoming > 0 { + log.Printf("Target bandwidth for outgoing streams: %d MBit/s", maxOutgoing) + } else { + log.Printf("Target bandwidth for outgoing streams: unlimited") + } + result := &ProxyServer{ version: version, country: country, welcomeMessage: string(welcomeMessage) + "\n", + config: config, + maxIncoming: int64(maxIncoming) * 1024 * 1024, + maxOutgoing: int64(maxOutgoing) * 1024 * 1024, shutdownChannel: make(chan struct{}), @@ -208,6 +328,12 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* clients: make(map[string]signaling.McuClient), clientIds: make(map[string]string), + + tokenId: tokenId, + tokenKey: tokenKey, + remoteTlsConfig: remoteTlsConfig, + remoteHostname: remoteHostname, + remoteConnections: make(map[string]*RemoteConnection), } result.upgrader.CheckOrigin = result.checkOrigin @@ -260,7 +386,7 @@ func (s *ProxyServer) Start(config *goconf.ConfigFile) error { for { switch mcuType { case signaling.McuTypeJanus: - mcu, err = signaling.NewMcuJanus(s.url, config) + mcu, err = signaling.NewMcuJanus(ctx, s.url, config) if err == nil { signaling.RegisterJanusMcuStats() } @@ -270,7 +396,7 @@ func (s *ProxyServer) Start(config *goconf.ConfigFile) error { if err == nil { mcu.SetOnConnected(s.onMcuConnected) mcu.SetOnDisconnected(s.onMcuDisconnected) - err = mcu.Start() + err = mcu.Start(ctx) if err != nil { log.Printf("Could not create %s MCU at %s: %s", mcuType, s.url, err) } @@ -313,18 +439,7 @@ loop: } } -func (s *ProxyServer) updateLoad() { - load := s.GetClientsLoad() - if load == s.load.Load() { - return - } - - s.load.Store(load) - if s.shutdownScheduled.Load() { - // Server is scheduled to shutdown, no need to update clients with current load. - return - } - +func (s *ProxyServer) newLoadEvent(load int64, incoming int64, outgoing int64) *signaling.ProxyServerMessage { msg := &signaling.ProxyServerMessage{ Type: "event", Event: &signaling.EventProxyServerMessage{ @@ -332,7 +447,37 @@ func (s *ProxyServer) updateLoad() { Load: load, }, } + if s.maxIncoming > 0 || s.maxOutgoing > 0 { + msg.Event.Bandwidth = &signaling.EventProxyServerBandwidth{} + if s.maxIncoming > 0 { + value := float64(incoming) / float64(s.maxIncoming) * 100 + msg.Event.Bandwidth.Incoming = &value + } + if s.maxOutgoing > 0 { + value := float64(outgoing) / float64(s.maxOutgoing) * 100 + msg.Event.Bandwidth.Outgoing = &value + } + } + return msg +} +func (s *ProxyServer) updateLoad() { + load, incoming, outgoing := s.GetClientsLoad() + if load == s.load.Load() && + incoming == s.currentIncoming.Load() && + outgoing == s.currentOutgoing.Load() { + return + } + + s.load.Store(load) + s.currentIncoming.Store(incoming) + s.currentOutgoing.Store(outgoing) + if s.shutdownScheduled.Load() { + // Server is scheduled to shutdown, no need to update clients with current load. + return + } + + msg := s.newLoadEvent(load, incoming, outgoing) s.IterateSessions(func(session *ProxySession) { session.sendMessage(msg) }) @@ -476,13 +621,7 @@ func (s *ProxyServer) onMcuDisconnected() { } func (s *ProxyServer) sendCurrentLoad(session *ProxySession) { - msg := &signaling.ProxyServerMessage{ - Type: "event", - Event: &signaling.EventProxyServerMessage{ - Type: "update-load", - Load: s.load.Load(), - }, - } + msg := s.newLoadEvent(s.load.Load(), s.currentIncoming.Load(), s.currentOutgoing.Load()) session.sendMessage(msg) } @@ -610,6 +749,59 @@ func (i *emptyInitiator) Country() string { return "" } +type proxyRemotePublisher struct { + proxy *ProxyServer + remoteUrl string + + publisherId string +} + +func (p *proxyRemotePublisher) PublisherId() string { + return p.publisherId +} + +func (p *proxyRemotePublisher) StartPublishing(ctx context.Context, publisher signaling.McuRemotePublisherProperties) error { + conn, err := p.proxy.getRemoteConnection(p.remoteUrl) + if err != nil { + return err + } + + if _, err := conn.RequestMessage(ctx, &signaling.ProxyClientMessage{ + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "publish-remote", + ClientId: p.publisherId, + Hostname: p.proxy.remoteHostname, + Port: publisher.Port(), + RtcpPort: publisher.RtcpPort(), + }, + }); err != nil { + return err + } + + return nil +} + +func (p *proxyRemotePublisher) GetStreams(ctx context.Context) ([]signaling.PublisherStream, error) { + conn, err := p.proxy.getRemoteConnection(p.remoteUrl) + if err != nil { + return nil, err + } + + response, err := conn.RequestMessage(ctx, &signaling.ProxyClientMessage{ + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "get-publisher-streams", + ClientId: p.publisherId, + }, + }) + if err != nil { + return nil, err + } + + return response.Command.Streams, nil +} + func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) { cmd := message.Command @@ -652,18 +844,89 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s case "create-subscriber": id := uuid.New().String() publisherId := cmd.PublisherId - subscriber, err := s.mcu.NewSubscriber(ctx, session, publisherId, cmd.StreamType) - if err == context.DeadlineExceeded { - log.Printf("Timeout while creating %s subscriber on %s for %s", cmd.StreamType, publisherId, session.PublicId()) - session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingSubscriber)) - return - } else if err != nil { + var subscriber signaling.McuSubscriber + var err error + + handleCreateError := func(err error) { + if err == context.DeadlineExceeded { + log.Printf("Timeout while creating %s subscriber on %s for %s", cmd.StreamType, publisherId, session.PublicId()) + session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingSubscriber)) + return + } else if errors.Is(err, signaling.ErrRemoteStreamsNotSupported) { + session.sendMessage(message.NewErrorServerMessage(RemoteSubscribersNotSupported)) + return + } + log.Printf("Error while creating %s subscriber on %s for %s: %s", cmd.StreamType, publisherId, session.PublicId(), err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) - return } - log.Printf("Created %s subscriber %s as %s for %s", cmd.StreamType, subscriber.Id(), id, session.PublicId()) + if cmd.RemoteUrl != "" { + if s.tokenId == "" || s.tokenKey == nil || s.remoteHostname == "" { + session.sendMessage(message.NewErrorServerMessage(RemoteSubscribersNotSupported)) + return + } + + remoteMcu, ok := s.mcu.(signaling.RemoteMcu) + if !ok { + session.sendMessage(message.NewErrorServerMessage(RemoteSubscribersNotSupported)) + return + } + + claims, _, err := s.parseToken(cmd.RemoteToken) + if err != nil { + if e, ok := err.(*signaling.Error); ok { + client.SendMessage(message.NewErrorServerMessage(e)) + } else { + client.SendMessage(message.NewWrappedErrorServerMessage(err)) + } + return + } + + if claims.Subject != publisherId { + session.sendMessage(message.NewErrorServerMessage(TokenAuthFailed)) + return + } + + subCtx, cancel := context.WithTimeout(ctx, remotePublisherTimeout) + defer cancel() + + log.Printf("Creating remote subscriber for %s on %s", publisherId, cmd.RemoteUrl) + + controller := &proxyRemotePublisher{ + proxy: s, + remoteUrl: cmd.RemoteUrl, + publisherId: publisherId, + } + + var publisher signaling.McuRemotePublisher + publisher, err = remoteMcu.NewRemotePublisher(subCtx, session, controller, cmd.StreamType) + if err != nil { + handleCreateError(err) + return + } + + defer func() { + go publisher.Close(context.Background()) + }() + + subscriber, err = remoteMcu.NewRemoteSubscriber(subCtx, session, publisher) + if err != nil { + handleCreateError(err) + return + } + + log.Printf("Created remote %s subscriber %s as %s for %s on %s", cmd.StreamType, subscriber.Id(), id, session.PublicId(), cmd.RemoteUrl) + } else { + subscriber, err = s.mcu.NewSubscriber(ctx, session, publisherId, cmd.StreamType, &emptyInitiator{}) + if err != nil { + handleCreateError(err) + return + } + + log.Printf("Created %s subscriber %s as %s for %s", cmd.StreamType, subscriber.Id(), id, session.PublicId()) + } + session.StoreSubscriber(ctx, id, subscriber) s.StoreClient(id, subscriber) @@ -748,6 +1011,77 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s }, } session.sendMessage(response) + case "publish-remote": + client := s.GetClient(cmd.ClientId) + if client == nil { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + publisher, ok := client.(signaling.McuPublisher) + if !ok { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + if err := publisher.PublishRemote(ctx, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { + var je *janus.ErrorMsg + if !errors.As(err, &je) || je.Err.Code != signaling.JANUS_VIDEOROOM_ERROR_ID_EXISTS { + log.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) + session.sendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + + if err := publisher.UnpublishRemote(ctx, session.PublicId()); err != nil { + log.Printf("Error unpublishing old %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) + session.sendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + + if err := publisher.PublishRemote(ctx, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { + log.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) + session.sendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + } + + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: cmd.ClientId, + }, + } + session.sendMessage(response) + case "get-publisher-streams": + client := s.GetClient(cmd.ClientId) + if client == nil { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + publisher, ok := client.(signaling.McuPublisher) + if !ok { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + streams, err := publisher.GetStreams(ctx) + if err != nil { + log.Printf("Could not get streams of publisher %s: %s", publisher.Id(), err) + session.sendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: cmd.ClientId, + Streams: streams, + }, + } + session.sendMessage(response) default: log.Printf("Unsupported command %+v", message.Command) session.sendMessage(message.NewErrorServerMessage(UnsupportedCommand)) @@ -830,13 +1164,9 @@ func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, s }) } -func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*ProxySession, error) { - if proxyDebugMessages { - log.Printf("Hello: %+v", hello) - } - +func (s *ProxyServer) parseToken(tokenValue string) (*signaling.TokenClaims, string, error) { reason := "auth-failed" - token, err := jwt.ParseWithClaims(hello.Token, &signaling.TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenValue, &signaling.TokenClaims{}, func(token *jwt.Token) (interface{}, error) { // Don't forget to validate the alg is what you expect: if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { log.Printf("Unexpected signing method: %v", token.Header["alg"]) @@ -868,25 +1198,35 @@ func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*Pro }) if err, ok := err.(*jwt.ValidationError); ok { if err.Errors&jwt.ValidationErrorIssuedAt == jwt.ValidationErrorIssuedAt { - statsTokenErrorsTotal.WithLabelValues("not-valid-yet").Inc() - return nil, TokenNotValidYet + return nil, "not-valid-yet", TokenNotValidYet } } if err != nil { - statsTokenErrorsTotal.WithLabelValues(reason).Inc() - return nil, TokenAuthFailed + return nil, reason, TokenAuthFailed } claims, ok := token.Claims.(*signaling.TokenClaims) if !ok || !token.Valid { - statsTokenErrorsTotal.WithLabelValues("auth-failed").Inc() - return nil, TokenAuthFailed + return nil, "auth-failed", TokenAuthFailed } minIssuedAt := time.Now().Add(-maxTokenAge) if issuedAt := claims.IssuedAt; issuedAt != nil && issuedAt.Before(minIssuedAt) { - statsTokenErrorsTotal.WithLabelValues("expired").Inc() - return nil, TokenExpired + return nil, "expired", TokenExpired + } + + return claims, "", nil +} + +func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*ProxySession, error) { + if proxyDebugMessages { + log.Printf("Hello: %+v", hello) + } + + claims, reason, err := s.parseToken(hello.Token) + if err != nil { + statsTokenErrorsTotal.WithLabelValues(reason).Inc() + return nil, err } sid := s.sid.Add(1) @@ -982,15 +1322,21 @@ func (s *ProxyServer) HasClients() bool { return len(s.clients) > 0 } -func (s *ProxyServer) GetClientsLoad() int64 { +func (s *ProxyServer) GetClientsLoad() (load int64, incoming int64, outgoing int64) { s.clientsLock.RLock() defer s.clientsLock.RUnlock() - var load int64 for _, c := range s.clients { - load += int64(c.MaxBitrate()) + bitrate := int64(c.MaxBitrate()) + load += bitrate + if _, ok := c.(signaling.McuPublisher); ok { + incoming += bitrate + } else if _, ok := c.(signaling.McuSubscriber); ok { + outgoing += bitrate + } } - return load / 1024 + load = load / 1024 + return } func (s *ProxyServer) GetClient(id string) signaling.McuClient { @@ -999,6 +1345,22 @@ func (s *ProxyServer) GetClient(id string) signaling.McuClient { return s.clients[id] } +func (s *ProxyServer) GetPublisher(publisherId string) signaling.McuPublisher { + s.clientsLock.RLock() + defer s.clientsLock.RUnlock() + for _, c := range s.clients { + pub, ok := c.(signaling.McuPublisher) + if !ok { + continue + } + + if pub.Id() == publisherId { + return pub + } + } + return nil +} + func (s *ProxyServer) GetClientId(client signaling.McuClient) string { s.clientsLock.RLock() defer s.clientsLock.RUnlock() @@ -1054,3 +1416,21 @@ func (s *ProxyServer) metricsHandler(w http.ResponseWriter, r *http.Request) { // Expose prometheus metrics at "/metrics". promhttp.Handler().ServeHTTP(w, r) } + +func (s *ProxyServer) getRemoteConnection(url string) (*RemoteConnection, error) { + s.remoteConnectionsLock.Lock() + defer s.remoteConnectionsLock.Unlock() + + conn, found := s.remoteConnections[url] + if found { + return conn, nil + } + + conn, err := NewRemoteConnection(url, s.tokenId, s.tokenKey, s.remoteTlsConfig) + if err != nil { + return nil, err + } + + s.remoteConnections[url] = conn + return conn, nil +} diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index 7ed87df..25a9a57 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -26,6 +26,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "net" "os" "testing" "time" @@ -92,6 +93,92 @@ func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey) { return server, key } +func TestTokenValid(t *testing.T) { + signaling.CatchLogForTest(t) + server, key := newProxyServerForTest(t) + + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)), + Issuer: TokenIdForTest, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + if err != nil { + t.Fatalf("could not create token: %s", err) + } + + hello := &signaling.HelloProxyClientMessage{ + Version: "1.0", + Token: tokenString, + } + session, err := server.NewSession(hello) + if session != nil { + defer session.Close() + } else if err != nil { + t.Error(err) + } +} + +func TestTokenNotSigned(t *testing.T) { + signaling.CatchLogForTest(t) + server, _ := newProxyServerForTest(t) + + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)), + Issuer: TokenIdForTest, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodNone, claims) + tokenString, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType) + if err != nil { + t.Fatalf("could not create token: %s", err) + } + + hello := &signaling.HelloProxyClientMessage{ + Version: "1.0", + Token: tokenString, + } + session, err := server.NewSession(hello) + if session != nil { + defer session.Close() + t.Errorf("should not have created session") + } else if err != TokenAuthFailed { + t.Errorf("could have failed with TokenAuthFailed, got %s", err) + } +} + +func TestTokenUnknown(t *testing.T) { + signaling.CatchLogForTest(t) + server, key := newProxyServerForTest(t) + + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)), + Issuer: TokenIdForTest + "2", + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + if err != nil { + t.Fatalf("could not create token: %s", err) + } + + hello := &signaling.HelloProxyClientMessage{ + Version: "1.0", + Token: tokenString, + } + session, err := server.NewSession(hello) + if session != nil { + defer session.Close() + t.Errorf("should not have created session") + } else if err != TokenAuthFailed { + t.Errorf("could have failed with TokenAuthFailed, got %s", err) + } +} + func TestTokenInFuture(t *testing.T) { signaling.CatchLogForTest(t) server, key := newProxyServerForTest(t) @@ -120,3 +207,67 @@ func TestTokenInFuture(t *testing.T) { t.Errorf("could have failed with TokenNotValidYet, got %s", err) } } + +func TestTokenExpired(t *testing.T) { + signaling.CatchLogForTest(t) + server, key := newProxyServerForTest(t) + + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge * 2)), + Issuer: TokenIdForTest, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + if err != nil { + t.Fatalf("could not create token: %s", err) + } + + hello := &signaling.HelloProxyClientMessage{ + Version: "1.0", + Token: tokenString, + } + session, err := server.NewSession(hello) + if session != nil { + defer session.Close() + t.Errorf("should not have created session") + } else if err != TokenExpired { + t.Errorf("could have failed with TokenExpired, got %s", err) + } +} + +func TestPublicIPs(t *testing.T) { + public := []string{ + "8.8.8.8", + "172.15.1.2", + "172.32.1.2", + "192.167.0.1", + "192.169.0.1", + } + private := []string{ + "127.0.0.1", + "10.1.2.3", + "172.16.1.2", + "172.31.1.2", + "192.168.0.1", + "192.168.254.254", + } + for _, s := range public { + ip := net.ParseIP(s) + if len(ip) == 0 { + t.Errorf("invalid IP: %s", s) + } else if !IsPublicIP(ip) { + t.Errorf("should be public IP: %s", s) + } + } + + for _, s := range private { + ip := net.ParseIP(s) + if len(ip) == 0 { + t.Errorf("invalid IP: %s", s) + } else if IsPublicIP(ip) { + t.Errorf("should be private IP: %s", s) + } + } +} diff --git a/publisher_stats_counter.go b/publisher_stats_counter.go new file mode 100644 index 0000000..ba8b293 --- /dev/null +++ b/publisher_stats_counter.go @@ -0,0 +1,99 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2021 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "sync" +) + +type publisherStatsCounter struct { + mu sync.Mutex + + streamTypes map[StreamType]bool + subscribers map[string]bool +} + +func (c *publisherStatsCounter) Reset() { + c.mu.Lock() + defer c.mu.Unlock() + + count := len(c.subscribers) + for streamType := range c.streamTypes { + statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec() + statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(count)) + } + c.streamTypes = nil + c.subscribers = nil +} + +func (c *publisherStatsCounter) EnableStream(streamType StreamType, enable bool) { + c.mu.Lock() + defer c.mu.Unlock() + + if enable == c.streamTypes[streamType] { + return + } + + if enable { + if c.streamTypes == nil { + c.streamTypes = make(map[StreamType]bool) + } + c.streamTypes[streamType] = true + statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Inc() + statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Add(float64(len(c.subscribers))) + } else { + delete(c.streamTypes, streamType) + statsMcuPublisherStreamTypesCurrent.WithLabelValues(string(streamType)).Dec() + statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Sub(float64(len(c.subscribers))) + } +} + +func (c *publisherStatsCounter) AddSubscriber(id string) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.subscribers[id] { + return + } + + if c.subscribers == nil { + c.subscribers = make(map[string]bool) + } + c.subscribers[id] = true + for streamType := range c.streamTypes { + statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Inc() + } +} + +func (c *publisherStatsCounter) RemoveSubscriber(id string) { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.subscribers[id] { + return + } + + delete(c.subscribers, id) + for streamType := range c.streamTypes { + statsMcuSubscriberStreamTypesCurrent.WithLabelValues(string(streamType)).Dec() + } +} diff --git a/mcu_janus_test.go b/publisher_stats_counter_test.go similarity index 100% rename from mcu_janus_test.go rename to publisher_stats_counter_test.go diff --git a/server/main.go b/server/main.go index 9ee6afd..a31a0f5 100644 --- a/server/main.go +++ b/server/main.go @@ -22,6 +22,7 @@ package main import ( + "context" "crypto/tls" "errors" "flag" @@ -240,9 +241,11 @@ func main() { mcuRetryTimer := time.NewTimer(mcuRetry) mcuTypeLoop: for { + // Context should be cancelled on signals but need a way to differentiate later. + ctx := context.TODO() switch mcuType { case signaling.McuTypeJanus: - mcu, err = signaling.NewMcuJanus(mcuUrl, config) + mcu, err = signaling.NewMcuJanus(ctx, mcuUrl, config) signaling.UnregisterProxyMcuStats() signaling.RegisterJanusMcuStats() case signaling.McuTypeProxy: @@ -253,7 +256,7 @@ func main() { log.Fatal("Unsupported MCU type: ", mcuType) } if err == nil { - err = mcu.Start() + err = mcu.Start(ctx) if err != nil { log.Printf("Could not create %s MCU: %s", mcuType, err) } diff --git a/slices_go120.go b/slices_go120.go new file mode 100644 index 0000000..de80826 --- /dev/null +++ b/slices_go120.go @@ -0,0 +1,34 @@ +//go:build !go1.21 + +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "sort" +) + +func SlicesSortFunc[T any](l []T, f func(a T, b T) int) { + sort.Slice(l, func(i, j int) bool { + return f(l[i], l[j]) < 0 + }) +} diff --git a/slices_go121.go b/slices_go121.go new file mode 100644 index 0000000..bc41535 --- /dev/null +++ b/slices_go121.go @@ -0,0 +1,32 @@ +//go:build go1.21 + +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "slices" +) + +func SlicesSortFunc[T any](l []T, f func(a T, b T) int) { + slices.SortFunc(l, f) +}