Merge pull request #860 from strukturag/unpublish-remote

Notify remote to stop publishing when last local subscriber is closed.
This commit is contained in:
Joachim Bauch 2024-11-11 11:47:45 +01:00 committed by GitHub
commit 8a53bab9cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1624 additions and 20 deletions

View file

@ -219,6 +219,17 @@ type dummyGatewayListener struct {
func (l *dummyGatewayListener) ConnectionInterrupted() {
}
type JanusGatewayInterface interface {
Info(context.Context) (*InfoMsg, error)
Create(context.Context) (*JanusSession, error)
Close() error
send(map[string]interface{}, *transaction) (uint64, error)
removeTransaction(uint64)
removeSession(*JanusSession)
}
// Gateway represents a connection to an instance of the Janus Gateway.
type JanusGateway struct {
listener GatewayListener
@ -560,12 +571,18 @@ func (gateway *JanusGateway) Create(ctx context.Context) (*JanusSession, error)
// Store this session
gateway.Lock()
defer gateway.Unlock()
gateway.Sessions[session.Id] = session
gateway.Unlock()
return session, nil
}
func (gateway *JanusGateway) removeSession(session *JanusSession) {
gateway.Lock()
defer gateway.Unlock()
delete(gateway.Sessions, session.Id)
}
// Session represents a session instance on the Janus Gateway.
type JanusSession struct {
// Id is the session_id of this session
@ -578,7 +595,7 @@ type JanusSession struct {
// and Session.Unlock() methods provided by the embedded sync.Mutex.
sync.Mutex
gateway *JanusGateway
gateway JanusGatewayInterface
}
func (session *JanusSession) send(msg map[string]interface{}, t *transaction) (uint64, error) {
@ -670,9 +687,7 @@ func (session *JanusSession) Destroy(ctx context.Context) (*janus.AckMsg, error)
}
// Remove this session from the gateway
session.gateway.Lock()
delete(session.gateway.Sessions, session.Id)
session.gateway.Unlock()
session.gateway.removeSession(session)
return ack, nil
}

View file

@ -166,6 +166,7 @@ type RemotePublisherController interface {
PublisherId() string
StartPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error
StopPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error
GetStreams(ctx context.Context) ([]PublisherStream, error)
}
@ -214,7 +215,7 @@ type McuPublisher interface {
GetStreams(ctx context.Context) ([]PublisherStream, error)
PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error
UnpublishRemote(ctx context.Context, remoteId string) error
UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error
}
type McuSubscriber interface {

View file

@ -78,6 +78,11 @@ func convertIntValue(value interface{}) (uint64, error) {
return uint64(t), nil
case uint64:
return t, nil
case int:
if t < 0 {
return 0, fmt.Errorf("Unsupported int number: %+v", t)
}
return uint64(t), nil
case int64:
if t < 0 {
return 0, fmt.Errorf("Unsupported int64 number: %+v", t)
@ -92,7 +97,7 @@ func convertIntValue(value interface{}) (uint64, error) {
}
return uint64(r), nil
default:
return 0, fmt.Errorf("Unknown number type: %+v", t)
return 0, fmt.Errorf("Unknown number type: %+v (%T)", t, t)
}
}
@ -170,7 +175,9 @@ type mcuJanus struct {
settings McuSettings
gw *JanusGateway
createJanusGateway func(ctx context.Context, wsURL string, listener GatewayListener) (JanusGatewayInterface, error)
gw JanusGatewayInterface
session *JanusSession
handle *JanusHandle
@ -213,6 +220,9 @@ func NewMcuJanus(ctx context.Context, url string, config *goconf.ConfigFile) (Mc
publishers: make(map[string]*mcuJanusPublisher),
remotePublishers: make(map[string]*mcuJanusRemotePublisher),
createJanusGateway: func(ctx context.Context, wsURL string, listener GatewayListener) (JanusGatewayInterface, error) {
return NewJanusGateway(ctx, wsURL, listener)
},
reconnectInterval: initialReconnectInterval,
}
mcu.onConnected.Store(emptyOnConnected)
@ -222,8 +232,10 @@ func NewMcuJanus(ctx context.Context, url string, config *goconf.ConfigFile) (Mc
mcu.doReconnect(context.Background())
})
mcu.reconnectTimer.Stop()
if err := mcu.reconnect(ctx); err != nil {
return nil, err
if mcu.url != "" {
if err := mcu.reconnect(ctx); err != nil {
return nil, err
}
}
return mcu, nil
}
@ -252,7 +264,7 @@ func (m *mcuJanus) disconnect() {
func (m *mcuJanus) reconnect(ctx context.Context) error {
m.disconnect()
gw, err := NewJanusGateway(ctx, m.url, m)
gw, err := m.createJanusGateway(ctx, m.url, m)
if err != nil {
return err
}
@ -317,6 +329,11 @@ func (m *mcuJanus) hasRemotePublisher() bool {
}
func (m *mcuJanus) Start(ctx context.Context) error {
if m.url == "" {
if err := m.reconnect(ctx); err != nil {
return err
}
}
info, err := m.gw.Info(ctx)
if err != nil {
return err
@ -785,6 +802,8 @@ func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller Re
settings: settings,
},
controller: controller,
port: int(port),
rtcpPort: int(rtcp_port),
}

View file

@ -380,8 +380,8 @@ func (p *mcuJanusPublisher) GetStreams(ctx context.Context) ([]PublisherStream,
return streams, nil
}
func getPublisherRemoteId(id string, remoteId string) string {
return fmt.Sprintf("%s@%s", id, remoteId)
func getPublisherRemoteId(id string, remoteId string, hostname string, port int, rtcpPort int) string {
return fmt.Sprintf("%s-%s@%s:%d:%d", id, remoteId, hostname, port, rtcpPort)
}
func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error {
@ -389,7 +389,7 @@ func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId string,
"request": "publish_remotely",
"room": p.roomId,
"publisher_id": streamTypeUserIds[p.streamType],
"remote_id": getPublisherRemoteId(p.id, remoteId),
"remote_id": getPublisherRemoteId(p.id, remoteId, hostname, port, rtcpPort),
"host": hostname,
"port": port,
"rtcp_port": rtcpPort,
@ -421,12 +421,12 @@ func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId string,
return nil
}
func (p *mcuJanusPublisher) UnpublishRemote(ctx context.Context, remoteId string) error {
func (p *mcuJanusPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error {
msg := map[string]interface{}{
"request": "unpublish_remotely",
"room": p.roomId,
"publisher_id": streamTypeUserIds[p.streamType],
"remote_id": getPublisherRemoteId(p.id, remoteId),
"remote_id": getPublisherRemoteId(p.id, remoteId, hostname, port, rtcpPort),
}
response, err := p.handle.Request(ctx, msg)
if err != nil {

View file

@ -34,6 +34,8 @@ type mcuJanusRemotePublisher struct {
ref atomic.Int64
controller RemotePublisherController
port int
rtcpPort int
}
@ -116,6 +118,10 @@ func (p *mcuJanusRemotePublisher) Close(ctx context.Context) {
return
}
if err := p.controller.StopPublishing(ctx, p); err != nil {
log.Printf("Error stopping remote publisher %s in room %d: %s", p.id, p.roomId, err)
}
p.mu.Lock()
if handle := p.handle; handle != nil {
response, err := p.handle.Request(ctx, map[string]interface{}{

584
mcu_janus_test.go Normal file
View file

@ -0,0 +1,584 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2024 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"encoding/json"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/dlintw/goconf"
"github.com/notedit/janus-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type TestJanusHandle struct {
id uint64
}
type TestJanusRoom struct {
id uint64
}
type TestJanusHandler func(room *TestJanusRoom, body map[string]interface{}) (interface{}, *janus.ErrorMsg)
type TestJanusGateway struct {
t *testing.T
sid atomic.Uint64
tid atomic.Uint64
hid atomic.Uint64
rid atomic.Uint64
mu sync.Mutex
sessions map[uint64]*JanusSession
transactions map[uint64]*transaction
handles map[uint64]*TestJanusHandle
rooms map[uint64]*TestJanusRoom
handlers map[string]TestJanusHandler
}
func NewTestJanusGateway(t *testing.T) *TestJanusGateway {
gateway := &TestJanusGateway{
t: t,
sessions: make(map[uint64]*JanusSession),
transactions: make(map[uint64]*transaction),
handles: make(map[uint64]*TestJanusHandle),
rooms: make(map[uint64]*TestJanusRoom),
handlers: make(map[string]TestJanusHandler),
}
t.Cleanup(func() {
assert := assert.New(t)
gateway.mu.Lock()
defer gateway.mu.Unlock()
assert.Len(gateway.sessions, 0)
assert.Len(gateway.transactions, 0)
assert.Len(gateway.handles, 0)
assert.Len(gateway.rooms, 0)
})
return gateway
}
func (g *TestJanusGateway) registerHandlers(handlers map[string]TestJanusHandler) {
g.mu.Lock()
defer g.mu.Unlock()
for name, handler := range handlers {
g.handlers[name] = handler
}
}
func (g *TestJanusGateway) Info(ctx context.Context) (*InfoMsg, error) {
return &InfoMsg{
Name: "TestJanus",
Version: 1400,
VersionString: "1.4.0",
Author: "struktur AG",
DataChannels: true,
FullTrickle: true,
Plugins: map[string]janus.PluginInfo{
pluginVideoRoom: {
Name: "Test VideoRoom plugin",
VersionString: "0.0.0",
Author: "struktur AG",
},
},
}, nil
}
func (g *TestJanusGateway) Create(ctx context.Context) (*JanusSession, error) {
sid := g.sid.Add(1)
session := &JanusSession{
Id: sid,
Handles: make(map[uint64]*JanusHandle),
gateway: g,
}
g.mu.Lock()
defer g.mu.Unlock()
g.sessions[sid] = session
return session, nil
}
func (g *TestJanusGateway) Close() error {
return nil
}
func (g *TestJanusGateway) processMessage(session *JanusSession, handle *TestJanusHandle, body map[string]interface{}) interface{} {
request := body["request"].(string)
switch request {
case "create":
room := &TestJanusRoom{
id: g.rid.Add(1),
}
g.rooms[room.id] = room
return &janus.SuccessMsg{
PluginData: janus.PluginData{
Plugin: pluginVideoRoom,
Data: map[string]interface{}{
"room": room.id,
},
},
}
case "join":
rid := body["room"].(float64)
room := g.rooms[uint64(rid)]
if room == nil {
return &janus.ErrorMsg{
Err: janus.ErrorData{
Code: JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM,
Reason: "Room not found",
},
}
}
assert.Equal(g.t, "publisher", body["ptype"])
return &janus.EventMsg{
Session: session.Id,
Handle: handle.id,
Plugindata: janus.PluginData{
Plugin: pluginVideoRoom,
Data: map[string]interface{}{
"room": room.id,
},
},
}
case "destroy":
rid := body["room"].(float64)
room := g.rooms[uint64(rid)]
if room == nil {
return &janus.ErrorMsg{
Err: janus.ErrorData{
Code: JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM,
Reason: "Room not found",
},
}
}
delete(g.rooms, uint64(rid))
return &janus.SuccessMsg{
PluginData: janus.PluginData{
Plugin: pluginVideoRoom,
Data: map[string]interface{}{},
},
}
default:
rid := body["room"].(float64)
room := g.rooms[uint64(rid)]
if room == nil {
return &janus.ErrorMsg{
Err: janus.ErrorData{
Code: JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM,
Reason: "Room not found",
},
}
}
handler, found := g.handlers[request]
if found {
var err *janus.ErrorMsg
result, err := handler(room, body)
if err != nil {
result = err
}
return result
}
}
return nil
}
func (g *TestJanusGateway) processRequest(msg map[string]interface{}) interface{} {
method, found := msg["janus"]
if !found {
return nil
}
sid := msg["session_id"].(float64)
g.mu.Lock()
defer g.mu.Unlock()
session := g.sessions[uint64(sid)]
if session == nil {
return &janus.ErrorMsg{
Err: janus.ErrorData{
Code: JANUS_ERROR_SESSION_NOT_FOUND,
Reason: "Session not found",
},
}
}
switch method {
case "attach":
handle := &TestJanusHandle{
id: g.hid.Add(1),
}
g.handles[handle.id] = handle
return &janus.SuccessMsg{
Data: janus.SuccessData{
ID: handle.id,
},
}
case "detach":
hid := msg["handle_id"].(float64)
handle, found := g.handles[uint64(hid)]
if found {
delete(g.handles, handle.id)
}
if handle == nil {
return &janus.ErrorMsg{
Err: janus.ErrorData{
Code: JANUS_ERROR_HANDLE_NOT_FOUND,
Reason: "Handle not found",
},
}
}
return &janus.AckMsg{}
case "destroy":
delete(g.sessions, session.Id)
return &janus.AckMsg{}
case "message":
hid := msg["handle_id"].(float64)
handle, found := g.handles[uint64(hid)]
if !found {
return &janus.ErrorMsg{
Err: janus.ErrorData{
Code: JANUS_ERROR_HANDLE_NOT_FOUND,
Reason: "Handle not found",
},
}
}
body := msg["body"].(map[string]interface{})
return g.processMessage(session, handle, body)
}
return nil
}
func (g *TestJanusGateway) send(msg map[string]interface{}, t *transaction) (uint64, error) {
tid := g.tid.Add(1)
data, err := json.Marshal(msg)
require.NoError(g.t, err)
err = json.Unmarshal(data, &msg)
require.NoError(g.t, err)
go t.run()
g.mu.Lock()
defer g.mu.Unlock()
g.transactions[tid] = t
go func() {
result := g.processRequest(msg)
if !assert.NotNil(g.t, result, "Unsupported request %+v", msg) {
result = &janus.ErrorMsg{
Err: janus.ErrorData{
Code: JANUS_ERROR_UNKNOWN,
Reason: "Not implemented",
},
}
}
t.add(result)
}()
return tid, nil
}
func (g *TestJanusGateway) removeTransaction(id uint64) {
g.mu.Lock()
defer g.mu.Unlock()
delete(g.transactions, id)
}
func (g *TestJanusGateway) removeSession(session *JanusSession) {
g.mu.Lock()
defer g.mu.Unlock()
delete(g.sessions, session.Id)
}
func newMcuJanusForTesting(t *testing.T) (*mcuJanus, *TestJanusGateway) {
gateway := NewTestJanusGateway(t)
config := goconf.NewConfigFile()
mcu, err := NewMcuJanus(context.Background(), "", config)
require.NoError(t, err)
t.Cleanup(func() {
mcu.Stop()
})
mcuJanus := mcu.(*mcuJanus)
mcuJanus.createJanusGateway = func(ctx context.Context, wsURL string, listener GatewayListener) (JanusGatewayInterface, error) {
return gateway, nil
}
require.NoError(t, mcu.Start(context.Background()))
return mcuJanus, gateway
}
type TestMcuListener struct {
id string
}
func (t *TestMcuListener) PublicId() string {
return t.id
}
func (t *TestMcuListener) OnUpdateOffer(client McuClient, offer map[string]interface{}) {
}
func (t *TestMcuListener) OnIceCandidate(client McuClient, candidate interface{}) {
}
func (t *TestMcuListener) OnIceCompleted(client McuClient) {
}
func (t *TestMcuListener) SubscriberSidUpdated(subscriber McuSubscriber) {
}
func (t *TestMcuListener) PublisherClosed(publisher McuPublisher) {
}
func (t *TestMcuListener) SubscriberClosed(subscriber McuSubscriber) {
}
type TestMcuController struct {
id string
}
func (c *TestMcuController) PublisherId() string {
return c.id
}
func (c *TestMcuController) StartPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error {
// TODO: Check parameters?
return nil
}
func (c *TestMcuController) StopPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error {
// TODO: Check parameters?
return nil
}
func (c *TestMcuController) GetStreams(ctx context.Context) ([]PublisherStream, error) {
streams := []PublisherStream{
{
Mid: "0",
Mindex: 0,
Type: "audio",
Codec: "opus",
},
}
return streams, nil
}
type TestMcuInitiator struct {
country string
}
func (i *TestMcuInitiator) Country() string {
return i.country
}
func Test_JanusPublisherSubscriber(t *testing.T) {
CatchLogForTest(t)
t.Parallel()
require := require.New(t)
mcu, gateway := newMcuJanusForTesting(t)
gateway.registerHandlers(map[string]TestJanusHandler{})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
pubId := "publisher-id"
listener1 := &TestMcuListener{
id: pubId,
}
settings1 := NewPublisherSettings{}
initiator1 := &TestMcuInitiator{
country: "DE",
}
pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", StreamTypeVideo, settings1, initiator1)
require.NoError(err)
defer pub.Close(context.Background())
listener2 := &TestMcuListener{
id: pubId,
}
initiator2 := &TestMcuInitiator{
country: "DE",
}
sub, err := mcu.NewSubscriber(ctx, listener2, pubId, StreamTypeVideo, initiator2)
require.NoError(err)
defer sub.Close(context.Background())
}
func Test_JanusSubscriberPublisher(t *testing.T) {
CatchLogForTest(t)
t.Parallel()
require := require.New(t)
mcu, gateway := newMcuJanusForTesting(t)
gateway.registerHandlers(map[string]TestJanusHandler{})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
pubId := "publisher-id"
listener1 := &TestMcuListener{
id: pubId,
}
settings1 := NewPublisherSettings{}
initiator1 := &TestMcuInitiator{
country: "DE",
}
ready := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
time.Sleep(100 * time.Millisecond)
pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", StreamTypeVideo, settings1, initiator1)
require.NoError(err)
defer func() {
<-ready
pub.Close(context.Background())
}()
}()
listener2 := &TestMcuListener{
id: pubId,
}
initiator2 := &TestMcuInitiator{
country: "DE",
}
sub, err := mcu.NewSubscriber(ctx, listener2, pubId, StreamTypeVideo, initiator2)
require.NoError(err)
defer sub.Close(context.Background())
close(ready)
<-done
}
func Test_JanusRemotePublisher(t *testing.T) {
CatchLogForTest(t)
t.Parallel()
assert := assert.New(t)
require := require.New(t)
var added atomic.Int32
var removed atomic.Int32
mcu, gateway := newMcuJanusForTesting(t)
gateway.registerHandlers(map[string]TestJanusHandler{
"add_remote_publisher": func(room *TestJanusRoom, body map[string]interface{}) (interface{}, *janus.ErrorMsg) {
assert.EqualValues(1, room.id)
if streams := body["streams"].([]interface{}); assert.Len(streams, 1) {
stream := streams[0].(map[string]interface{})
assert.Equal("0", stream["mid"])
assert.EqualValues(0, stream["mindex"])
assert.Equal("audio", stream["type"])
assert.Equal("opus", stream["codec"])
}
added.Add(1)
return &janus.SuccessMsg{
PluginData: janus.PluginData{
Plugin: pluginVideoRoom,
Data: map[string]interface{}{
"id": 12345,
"port": 10000,
"rtcp_port": 10001,
},
},
}, nil
},
"remove_remote_publisher": func(room *TestJanusRoom, body map[string]interface{}) (interface{}, *janus.ErrorMsg) {
assert.EqualValues(1, room.id)
removed.Add(1)
return &janus.SuccessMsg{
PluginData: janus.PluginData{
Plugin: pluginVideoRoom,
Data: map[string]interface{}{},
},
}, nil
},
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
listener1 := &TestMcuListener{
id: "publisher-id",
}
controller := &TestMcuController{
id: listener1.id,
}
pub, err := mcu.NewRemotePublisher(ctx, listener1, controller, StreamTypeVideo)
require.NoError(err)
defer pub.Close(context.Background())
assert.EqualValues(1, added.Load())
assert.EqualValues(0, removed.Load())
listener2 := &TestMcuListener{
id: "subscriber-id",
}
sub, err := mcu.NewRemoteSubscriber(ctx, listener2, pub)
require.NoError(err)
defer sub.Close(context.Background())
pub.Close(context.Background())
assert.EqualValues(1, added.Load())
// The publisher is ref-counted, and still referenced by the subscriber.
assert.EqualValues(0, removed.Load())
sub.Close(context.Background())
assert.EqualValues(1, added.Load())
assert.EqualValues(1, removed.Load())
}

View file

@ -227,7 +227,7 @@ func (p *mcuProxyPublisher) PublishRemote(ctx context.Context, remoteId string,
return errors.New("remote publishing not supported for proxy publishers")
}
func (p *mcuProxyPublisher) UnpublishRemote(ctx context.Context, remoteId string) error {
func (p *mcuProxyPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error {
return errors.New("remote publishing not supported for proxy publishers")
}

View file

@ -1502,6 +1502,110 @@ func Test_ProxyRemotePublisher(t *testing.T) {
defer sub.Close(context.Background())
}
func Test_ProxyMultipleRemotePublisher(t *testing.T) {
CatchLogForTest(t)
t.Parallel()
etcd := NewEtcdForTest(t)
grpcServer1, addr1 := NewGrpcServerForTest(t)
grpcServer2, addr2 := NewGrpcServerForTest(t)
grpcServer3, addr3 := NewGrpcServerForTest(t)
hub1 := &mockGrpcServerHub{}
hub2 := &mockGrpcServerHub{}
hub3 := &mockGrpcServerHub{}
grpcServer1.hub = hub1
grpcServer2.hub = hub2
grpcServer3.hub = hub3
SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
SetEtcdValue(etcd, "/grpctargets/three", []byte("{\"address\":\""+addr3+"\"}"))
server1 := NewProxyServerForTest(t, "DE")
server2 := NewProxyServerForTest(t, "US")
server3 := NewProxyServerForTest(t, "US")
mcu1 := newMcuProxyForTestWithOptions(t, proxyTestOptions{
etcd: etcd,
servers: []*TestProxyServerHandler{
server1,
server2,
server3,
},
})
mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{
etcd: etcd,
servers: []*TestProxyServerHandler{
server1,
server2,
server3,
},
})
mcu3 := newMcuProxyForTestWithOptions(t, proxyTestOptions{
etcd: etcd,
servers: []*TestProxyServerHandler{
server1,
server2,
server3,
},
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
pubId := "the-publisher"
pubSid := "1234567890"
pubListener := &MockMcuListener{
publicId: pubId + "-public",
}
pubInitiator := &MockMcuInitiator{
country: "DE",
}
session1 := &ClientSession{
publicId: pubId,
publishers: make(map[StreamType]McuPublisher),
}
hub1.addSession(session1)
defer hub1.removeSession(session1)
pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, NewPublisherSettings{
MediaTypes: MediaTypeVideo | MediaTypeAudio,
}, pubInitiator)
require.NoError(t, err)
defer pub.Close(context.Background())
session1.mu.Lock()
session1.publishers[StreamTypeVideo] = pub
session1.publisherWaiters.Wakeup()
session1.mu.Unlock()
sub1Listener := &MockMcuListener{
publicId: "subscriber-public-1",
}
sub1Initiator := &MockMcuInitiator{
country: "US",
}
sub1, err := mcu2.NewSubscriber(ctx, sub1Listener, pubId, StreamTypeVideo, sub1Initiator)
require.NoError(t, err)
defer sub1.Close(context.Background())
sub2Listener := &MockMcuListener{
publicId: "subscriber-public-2",
}
sub2Initiator := &MockMcuInitiator{
country: "US",
}
sub2, err := mcu3.NewSubscriber(ctx, sub2Listener, pubId, StreamTypeVideo, sub2Initiator)
require.NoError(t, err)
defer sub2.Close(context.Background())
}
func Test_ProxyRemotePublisherWait(t *testing.T) {
CatchLogForTest(t)
t.Parallel()

View file

@ -229,7 +229,7 @@ func (p *TestMCUPublisher) PublishRemote(ctx context.Context, remoteId string, h
return errors.New("remote publishing not supported")
}
func (p *TestMCUPublisher) UnpublishRemote(ctx context.Context, remoteId string) error {
func (p *TestMCUPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error {
return errors.New("remote publishing not supported")
}

View file

@ -856,6 +856,28 @@ func (p *proxyRemotePublisher) StartPublishing(ctx context.Context, publisher si
return nil
}
func (p *proxyRemotePublisher) StopPublishing(ctx context.Context, publisher signaling.McuRemotePublisherProperties) error {
conn, err := p.proxy.getRemoteConnection(p.remoteUrl)
if err != nil {
return err
}
if _, err := conn.RequestMessage(ctx, &signaling.ProxyClientMessage{
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "unpublish-remote",
ClientId: p.publisherId,
Hostname: p.proxy.remoteHostname,
Port: publisher.Port(),
RtcpPort: publisher.RtcpPort(),
},
}); err != nil {
return err
}
return nil
}
func (p *proxyRemotePublisher) GetStreams(ctx context.Context) ([]signaling.PublisherStream, error) {
conn, err := p.proxy.getRemoteConnection(p.remoteUrl)
if err != nil {
@ -1125,7 +1147,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
ctx2, cancel = context.WithTimeout(ctx, s.mcuTimeout)
defer cancel()
if err := publisher.UnpublishRemote(ctx2, session.PublicId()); err != nil {
if err := publisher.UnpublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil {
log.Printf("Error unpublishing old %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err)
session.sendMessage(message.NewWrappedErrorServerMessage(err))
return
@ -1141,6 +1163,39 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
}
}
session.AddRemotePublisher(publisher, cmd.Hostname, cmd.Port, cmd.RtcpPort)
response := &signaling.ProxyServerMessage{
Id: message.Id,
Type: "command",
Command: &signaling.CommandProxyServerMessage{
Id: cmd.ClientId,
},
}
session.sendMessage(response)
case "unpublish-remote":
client := s.GetClient(cmd.ClientId)
if client == nil {
session.sendMessage(message.NewErrorServerMessage(UnknownClient))
return
}
publisher, ok := client.(signaling.McuPublisher)
if !ok {
session.sendMessage(message.NewErrorServerMessage(UnknownClient))
return
}
ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout)
defer cancel()
if err := publisher.UnpublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil {
log.Printf("Error unpublishing %s %s from remote %s: %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, err)
session.sendMessage(message.NewWrappedErrorServerMessage(err))
return
}
session.RemoveRemotePublisher(publisher, cmd.Hostname, cmd.Port, cmd.RtcpPort)
response := &signaling.ProxyServerMessage{
Id: message.Id,
Type: "command",
@ -1547,3 +1602,12 @@ func (s *ProxyServer) getRemoteConnection(url string) (*RemoteConnection, error)
s.remoteConnections[url] = conn
return conn, nil
}
func (s *ProxyServer) PublisherDeleted(publisher signaling.McuPublisher) {
s.sessionsLock.RLock()
defer s.sessionsLock.RUnlock()
for _, session := range s.sessions {
session.OnPublisherDeleted(publisher)
}
}

View file

@ -33,6 +33,7 @@ import (
"net/http/httptest"
"os"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@ -431,7 +432,7 @@ func (p *TestMCUPublisher) PublishRemote(ctx context.Context, remoteId string, h
return errors.New("not implemented")
}
func (p *TestMCUPublisher) UnpublishRemote(ctx context.Context, remoteId string) error {
func (p *TestMCUPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error {
return errors.New("not implemented")
}
@ -618,3 +619,730 @@ func TestProxyCodecs(t *testing.T) {
}
}
}
type RemoteSubscriberTestMCU struct {
TestMCU
publisher *TestRemotePublisher
subscriber *TestRemoteSubscriber
}
func NewRemoteSubscriberTestMCU(t *testing.T) *RemoteSubscriberTestMCU {
return &RemoteSubscriberTestMCU{
TestMCU: TestMCU{
t: t,
},
}
}
type TestRemotePublisher struct {
t *testing.T
streamType signaling.StreamType
refcnt atomic.Int32
closed context.Context
closeFunc context.CancelFunc
}
func (p *TestRemotePublisher) Id() string {
return "id"
}
func (p *TestRemotePublisher) Sid() string {
return "sid"
}
func (p *TestRemotePublisher) StreamType() signaling.StreamType {
return p.streamType
}
func (p *TestRemotePublisher) MaxBitrate() int {
return 0
}
func (p *TestRemotePublisher) Close(ctx context.Context) {
if count := p.refcnt.Add(-1); assert.True(p.t, count >= 0) && count == 0 {
p.closeFunc()
}
}
func (p *TestRemotePublisher) SendMessage(ctx context.Context, message *signaling.MessageClientMessage, data *signaling.MessageClientMessageData, callback func(error, map[string]interface{})) {
callback(errors.New("not implemented"), nil)
}
func (p *TestRemotePublisher) Port() int {
return 1
}
func (p *TestRemotePublisher) RtcpPort() int {
return 2
}
func (m *RemoteSubscriberTestMCU) NewRemotePublisher(ctx context.Context, listener signaling.McuListener, controller signaling.RemotePublisherController, streamType signaling.StreamType) (signaling.McuRemotePublisher, error) {
require.Nil(m.t, m.publisher)
assert.EqualValues(m.t, "video", streamType)
closeCtx, closeFunc := context.WithCancel(context.Background())
m.publisher = &TestRemotePublisher{
t: m.t,
streamType: streamType,
closed: closeCtx,
closeFunc: closeFunc,
}
m.publisher.refcnt.Add(1)
return m.publisher, nil
}
type TestRemoteSubscriber struct {
t *testing.T
publisher *TestRemotePublisher
closed context.Context
closeFunc context.CancelFunc
}
func (s *TestRemoteSubscriber) Id() string {
return "id"
}
func (s *TestRemoteSubscriber) Sid() string {
return "sid"
}
func (s *TestRemoteSubscriber) StreamType() signaling.StreamType {
return s.publisher.StreamType()
}
func (s *TestRemoteSubscriber) MaxBitrate() int {
return 0
}
func (s *TestRemoteSubscriber) Close(ctx context.Context) {
s.publisher.Close(ctx)
s.closeFunc()
}
func (s *TestRemoteSubscriber) SendMessage(ctx context.Context, message *signaling.MessageClientMessage, data *signaling.MessageClientMessageData, callback func(error, map[string]interface{})) {
callback(errors.New("not implemented"), nil)
}
func (s *TestRemoteSubscriber) Publisher() string {
return s.publisher.Id()
}
func (m *RemoteSubscriberTestMCU) NewRemoteSubscriber(ctx context.Context, listener signaling.McuListener, publisher signaling.McuRemotePublisher) (signaling.McuRemoteSubscriber, error) {
require.Nil(m.t, m.subscriber)
pub, ok := publisher.(*TestRemotePublisher)
require.True(m.t, ok)
closeCtx, closeFunc := context.WithCancel(context.Background())
m.subscriber = &TestRemoteSubscriber{
t: m.t,
publisher: pub,
closed: closeCtx,
closeFunc: closeFunc,
}
pub.refcnt.Add(1)
return m.subscriber, nil
}
func TestProxyRemoteSubscriber(t *testing.T) {
signaling.CatchLogForTest(t)
assert := assert.New(t)
require := require.New(t)
proxy, key, server := newProxyServerForTest(t)
mcu := NewRemoteSubscriberTestMCU(t)
proxy.mcu = mcu
// Unused but must be set so remote subscribing works
proxy.tokenId = "token"
proxy.tokenKey = key
proxy.remoteHostname = "test-hostname"
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client := NewProxyTestClient(ctx, t, server.URL)
defer client.CloseWithBye()
require.NoError(client.SendHello(key))
if hello, err := client.RunUntilHello(ctx); assert.NoError(err) {
assert.NotEmpty(hello.Hello.SessionId, "%+v", hello)
}
_, err := client.RunUntilLoad(ctx, 0)
assert.NoError(err)
publisherId := "the-publisher-id"
claims := &signaling.TokenClaims{
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)),
Issuer: TokenIdForTest,
Subject: publisherId,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenString, err := token.SignedString(key)
require.NoError(err)
require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{
Id: "2345",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "create-subscriber",
StreamType: signaling.StreamTypeVideo,
PublisherId: publisherId,
RemoteUrl: "https://remote-hostname",
RemoteToken: tokenString,
},
}))
var clientId string
if message, err := client.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("2345", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
require.NotEmpty(message.Command.Id)
clientId = message.Command.Id
}
}
require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{
Id: "3456",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "delete-subscriber",
ClientId: clientId,
},
}))
if message, err := client.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("3456", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
assert.Equal(clientId, message.Command.Id)
}
}
if assert.NotNil(mcu.publisher) && assert.NotNil(mcu.subscriber) {
select {
case <-mcu.subscriber.closed.Done():
case <-ctx.Done():
assert.Fail("subscriber was not closed")
}
select {
case <-mcu.publisher.closed.Done():
case <-ctx.Done():
assert.Fail("publisher was not closed")
}
}
}
func TestProxyCloseRemoteOnSessionClose(t *testing.T) {
signaling.CatchLogForTest(t)
assert := assert.New(t)
require := require.New(t)
proxy, key, server := newProxyServerForTest(t)
mcu := NewRemoteSubscriberTestMCU(t)
proxy.mcu = mcu
// Unused but must be set so remote subscribing works
proxy.tokenId = "token"
proxy.tokenKey = key
proxy.remoteHostname = "test-hostname"
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client := NewProxyTestClient(ctx, t, server.URL)
defer client.CloseWithBye()
require.NoError(client.SendHello(key))
if hello, err := client.RunUntilHello(ctx); assert.NoError(err) {
assert.NotEmpty(hello.Hello.SessionId, "%+v", hello)
}
_, err := client.RunUntilLoad(ctx, 0)
assert.NoError(err)
publisherId := "the-publisher-id"
claims := &signaling.TokenClaims{
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)),
Issuer: TokenIdForTest,
Subject: publisherId,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenString, err := token.SignedString(key)
require.NoError(err)
require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{
Id: "2345",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "create-subscriber",
StreamType: signaling.StreamTypeVideo,
PublisherId: publisherId,
RemoteUrl: "https://remote-hostname",
RemoteToken: tokenString,
},
}))
if message, err := client.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("2345", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
require.NotEmpty(message.Command.Id)
}
}
// Closing the session will cause any active remote publishers stop be stopped.
client.CloseWithBye()
if assert.NotNil(mcu.publisher) && assert.NotNil(mcu.subscriber) {
select {
case <-mcu.subscriber.closed.Done():
case <-ctx.Done():
assert.Fail("subscriber was not closed")
}
select {
case <-mcu.publisher.closed.Done():
case <-ctx.Done():
assert.Fail("publisher was not closed")
}
}
}
type UnpublishRemoteTestMCU struct {
TestMCU
publisher atomic.Pointer[UnpublishRemoteTestPublisher]
}
func NewUnpublishRemoteTestMCU(t *testing.T) *UnpublishRemoteTestMCU {
return &UnpublishRemoteTestMCU{
TestMCU: TestMCU{
t: t,
},
}
}
type UnpublishRemoteTestPublisher struct {
TestMCUPublisher
t *testing.T
mu sync.RWMutex
remoteId string
remoteData *remotePublisherData
}
func (m *UnpublishRemoteTestMCU) NewPublisher(ctx context.Context, listener signaling.McuListener, id string, sid string, streamType signaling.StreamType, settings signaling.NewPublisherSettings, initiator signaling.McuInitiator) (signaling.McuPublisher, error) {
publisher := &UnpublishRemoteTestPublisher{
TestMCUPublisher: TestMCUPublisher{
id: id,
sid: sid,
streamType: streamType,
},
t: m.t,
}
m.publisher.Store(publisher)
return publisher, nil
}
func (p *UnpublishRemoteTestPublisher) getRemoteId() string {
p.mu.RLock()
defer p.mu.RUnlock()
return p.remoteId
}
func (p *UnpublishRemoteTestPublisher) getRemoteData() *remotePublisherData {
p.mu.RLock()
defer p.mu.RUnlock()
return p.remoteData
}
func (p *UnpublishRemoteTestPublisher) clearRemote() {
p.mu.Lock()
defer p.mu.Unlock()
p.remoteId = ""
p.remoteData = nil
}
func (p *UnpublishRemoteTestPublisher) PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error {
p.mu.Lock()
defer p.mu.Unlock()
if assert.Empty(p.t, p.remoteId) {
p.remoteId = remoteId
p.remoteData = &remotePublisherData{
hostname: hostname,
port: port,
rtcpPort: rtcpPort,
}
}
return nil
}
func (p *UnpublishRemoteTestPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error {
p.mu.Lock()
defer p.mu.Unlock()
assert.Equal(p.t, remoteId, p.remoteId)
if remoteData := p.remoteData; assert.NotNil(p.t, remoteData) &&
assert.Equal(p.t, remoteData.hostname, hostname) &&
assert.EqualValues(p.t, remoteData.port, port) &&
assert.EqualValues(p.t, remoteData.rtcpPort, rtcpPort) {
p.remoteId = ""
p.remoteData = nil
}
return nil
}
func TestProxyUnpublishRemote(t *testing.T) {
signaling.CatchLogForTest(t)
assert := assert.New(t)
require := require.New(t)
proxy, key, server := newProxyServerForTest(t)
mcu := NewUnpublishRemoteTestMCU(t)
proxy.mcu = mcu
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client1 := NewProxyTestClient(ctx, t, server.URL)
defer client1.CloseWithBye()
require.NoError(client1.SendHello(key))
if hello, err := client1.RunUntilHello(ctx); assert.NoError(err) {
assert.NotEmpty(hello.Hello.SessionId, "%+v", hello)
}
_, err := client1.RunUntilLoad(ctx, 0)
assert.NoError(err)
publisherId := "the-publisher-id"
require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{
Id: "2345",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "create-publisher",
PublisherId: publisherId,
Sid: "1234-abcd",
StreamType: signaling.StreamTypeVideo,
PublisherSettings: &signaling.NewPublisherSettings{
Bitrate: 1234567,
MediaTypes: signaling.MediaTypeAudio | signaling.MediaTypeVideo,
},
},
}))
var clientId string
if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("2345", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
require.NotEmpty(message.Command.Id)
clientId = message.Command.Id
}
}
client2 := NewProxyTestClient(ctx, t, server.URL)
defer client2.CloseWithBye()
require.NoError(client2.SendHello(key))
hello2, err := client2.RunUntilHello(ctx)
if assert.NoError(err) {
assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2)
}
_, err = client2.RunUntilLoad(ctx, 0)
assert.NoError(err)
require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{
Id: "3456",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "publish-remote",
StreamType: signaling.StreamTypeVideo,
ClientId: clientId,
Hostname: "remote-host",
Port: 10001,
RtcpPort: 10002,
},
}))
if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("3456", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
require.NotEmpty(message.Command.Id)
}
}
if publisher := mcu.publisher.Load(); assert.NotNil(publisher) {
assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId())
if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) {
assert.Equal("remote-host", remoteData.hostname)
assert.EqualValues(10001, remoteData.port)
assert.EqualValues(10002, remoteData.rtcpPort)
}
}
require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{
Id: "4567",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "unpublish-remote",
StreamType: signaling.StreamTypeVideo,
ClientId: clientId,
Hostname: "remote-host",
Port: 10001,
RtcpPort: 10002,
},
}))
if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("4567", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
require.NotEmpty(message.Command.Id)
}
}
if publisher := mcu.publisher.Load(); assert.NotNil(publisher) {
assert.Empty(publisher.getRemoteId())
assert.Nil(publisher.getRemoteData())
}
}
func TestProxyUnpublishRemotePublisherClosed(t *testing.T) {
signaling.CatchLogForTest(t)
assert := assert.New(t)
require := require.New(t)
proxy, key, server := newProxyServerForTest(t)
mcu := NewUnpublishRemoteTestMCU(t)
proxy.mcu = mcu
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client1 := NewProxyTestClient(ctx, t, server.URL)
defer client1.CloseWithBye()
require.NoError(client1.SendHello(key))
if hello, err := client1.RunUntilHello(ctx); assert.NoError(err) {
assert.NotEmpty(hello.Hello.SessionId, "%+v", hello)
}
_, err := client1.RunUntilLoad(ctx, 0)
assert.NoError(err)
publisherId := "the-publisher-id"
require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{
Id: "2345",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "create-publisher",
PublisherId: publisherId,
Sid: "1234-abcd",
StreamType: signaling.StreamTypeVideo,
PublisherSettings: &signaling.NewPublisherSettings{
Bitrate: 1234567,
MediaTypes: signaling.MediaTypeAudio | signaling.MediaTypeVideo,
},
},
}))
var clientId string
if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("2345", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
require.NotEmpty(message.Command.Id)
clientId = message.Command.Id
}
}
client2 := NewProxyTestClient(ctx, t, server.URL)
defer client2.CloseWithBye()
require.NoError(client2.SendHello(key))
hello2, err := client2.RunUntilHello(ctx)
if assert.NoError(err) {
assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2)
}
_, err = client2.RunUntilLoad(ctx, 0)
assert.NoError(err)
require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{
Id: "3456",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "publish-remote",
StreamType: signaling.StreamTypeVideo,
ClientId: clientId,
Hostname: "remote-host",
Port: 10001,
RtcpPort: 10002,
},
}))
if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("3456", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
require.NotEmpty(message.Command.Id)
}
}
if publisher := mcu.publisher.Load(); assert.NotNil(publisher) {
assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId())
if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) {
assert.Equal("remote-host", remoteData.hostname)
assert.EqualValues(10001, remoteData.port)
assert.EqualValues(10002, remoteData.rtcpPort)
}
}
require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{
Id: "4567",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "delete-publisher",
ClientId: clientId,
},
}))
if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("4567", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
require.NotEmpty(message.Command.Id)
}
}
// Remote publishing was not stopped explicitly...
if publisher := mcu.publisher.Load(); assert.NotNil(publisher) {
assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId())
if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) {
assert.Equal("remote-host", remoteData.hostname)
assert.EqualValues(10001, remoteData.port)
assert.EqualValues(10002, remoteData.rtcpPort)
}
}
// ...but the session no longer contains information on the remote publisher.
if data, err := proxy.cookie.DecodePublic(hello2.Hello.SessionId); assert.NoError(err) {
session := proxy.GetSession(data.Sid)
if assert.NotNil(session) {
session.remotePublishersLock.Lock()
defer session.remotePublishersLock.Unlock()
assert.Empty(session.remotePublishers)
}
}
if publisher := mcu.publisher.Load(); assert.NotNil(publisher) {
publisher.clearRemote()
}
}
func TestProxyUnpublishRemoteOnSessionClose(t *testing.T) {
signaling.CatchLogForTest(t)
assert := assert.New(t)
require := require.New(t)
proxy, key, server := newProxyServerForTest(t)
mcu := NewUnpublishRemoteTestMCU(t)
proxy.mcu = mcu
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client1 := NewProxyTestClient(ctx, t, server.URL)
defer client1.CloseWithBye()
require.NoError(client1.SendHello(key))
if hello, err := client1.RunUntilHello(ctx); assert.NoError(err) {
assert.NotEmpty(hello.Hello.SessionId, "%+v", hello)
}
_, err := client1.RunUntilLoad(ctx, 0)
assert.NoError(err)
publisherId := "the-publisher-id"
require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{
Id: "2345",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "create-publisher",
PublisherId: publisherId,
Sid: "1234-abcd",
StreamType: signaling.StreamTypeVideo,
PublisherSettings: &signaling.NewPublisherSettings{
Bitrate: 1234567,
MediaTypes: signaling.MediaTypeAudio | signaling.MediaTypeVideo,
},
},
}))
var clientId string
if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("2345", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
require.NotEmpty(message.Command.Id)
clientId = message.Command.Id
}
}
client2 := NewProxyTestClient(ctx, t, server.URL)
defer client2.CloseWithBye()
require.NoError(client2.SendHello(key))
hello2, err := client2.RunUntilHello(ctx)
if assert.NoError(err) {
assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2)
}
_, err = client2.RunUntilLoad(ctx, 0)
assert.NoError(err)
require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{
Id: "3456",
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "publish-remote",
StreamType: signaling.StreamTypeVideo,
ClientId: clientId,
Hostname: "remote-host",
Port: 10001,
RtcpPort: 10002,
},
}))
if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) {
assert.Equal("3456", message.Id)
if err := checkMessageType(message, "command"); assert.NoError(err) {
require.NotEmpty(message.Command.Id)
}
}
if publisher := mcu.publisher.Load(); assert.NotNil(publisher) {
assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId())
if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) {
assert.Equal("remote-host", remoteData.hostname)
assert.EqualValues(10001, remoteData.port)
assert.EqualValues(10002, remoteData.rtcpPort)
}
}
// Closing the session will cause any active remote publishers stop be stopped.
client2.CloseWithBye()
if publisher := mcu.publisher.Load(); assert.NotNil(publisher) {
assert.Empty(publisher.getRemoteId())
assert.Nil(publisher.getRemoteData())
}
}

View file

@ -23,6 +23,7 @@ package main
import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
@ -36,6 +37,12 @@ const (
sessionExpirationTime = time.Minute
)
type remotePublisherData struct {
hostname string
port int
rtcpPort int
}
type ProxySession struct {
proxy *ProxyServer
id string
@ -55,6 +62,9 @@ type ProxySession struct {
subscribersLock sync.Mutex
subscribers map[string]signaling.McuSubscriber
subscriberIds map[signaling.McuSubscriber]string
remotePublishersLock sync.Mutex
remotePublishers map[signaling.McuPublisher]map[string]*remotePublisherData
}
func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession {
@ -121,6 +131,7 @@ func (s *ProxySession) Close() {
s.closeFunc()
s.clearPublishers()
s.clearSubscribers()
s.clearRemotePublishers()
s.proxy.DeleteSession(s.Sid())
}
@ -287,6 +298,8 @@ func (s *ProxySession) DeletePublisher(publisher signaling.McuPublisher) string
delete(s.publishers, id)
delete(s.publisherIds, publisher)
delete(s.remotePublishers, publisher)
go s.proxy.PublisherDeleted(publisher)
return id
}
@ -329,6 +342,22 @@ func (s *ProxySession) clearPublishers() {
clear(s.publisherIds)
}
func (s *ProxySession) clearRemotePublishers() {
s.remotePublishersLock.Lock()
defer s.remotePublishersLock.Unlock()
go func(remotePublishers map[signaling.McuPublisher]map[string]*remotePublisherData) {
for publisher, entries := range remotePublishers {
for _, data := range entries {
if err := publisher.UnpublishRemote(context.Background(), s.PublicId(), data.hostname, data.port, data.rtcpPort); err != nil {
log.Printf("Error unpublishing %s %s from remote %s: %s", publisher.StreamType(), publisher.Id(), data.hostname, err)
}
}
}
}(s.remotePublishers)
s.remotePublishers = nil
}
func (s *ProxySession) clearSubscribers() {
s.publishersLock.Lock()
defer s.publishersLock.Unlock()
@ -349,4 +378,58 @@ func (s *ProxySession) clearSubscribers() {
func (s *ProxySession) NotifyDisconnected() {
s.clearPublishers()
s.clearSubscribers()
s.clearRemotePublishers()
}
func (s *ProxySession) AddRemotePublisher(publisher signaling.McuPublisher, hostname string, port int, rtcpPort int) bool {
s.remotePublishersLock.Lock()
defer s.remotePublishersLock.Unlock()
remote, found := s.remotePublishers[publisher]
if !found {
remote = make(map[string]*remotePublisherData)
if s.remotePublishers == nil {
s.remotePublishers = make(map[signaling.McuPublisher]map[string]*remotePublisherData)
}
s.remotePublishers[publisher] = remote
}
key := fmt.Sprintf("%s:%d%d", hostname, port, rtcpPort)
if _, found := remote[key]; found {
return false
}
data := &remotePublisherData{
hostname: hostname,
port: port,
rtcpPort: rtcpPort,
}
remote[key] = data
return true
}
func (s *ProxySession) RemoveRemotePublisher(publisher signaling.McuPublisher, hostname string, port int, rtcpPort int) {
s.remotePublishersLock.Lock()
defer s.remotePublishersLock.Unlock()
remote, found := s.remotePublishers[publisher]
if !found {
return
}
key := fmt.Sprintf("%s:%d%d", hostname, port, rtcpPort)
delete(remote, key)
if len(remote) == 0 {
delete(s.remotePublishers, publisher)
if len(s.remotePublishers) == 0 {
s.remotePublishers = nil
}
}
}
func (s *ProxySession) OnPublisherDeleted(publisher signaling.McuPublisher) {
s.remotePublishersLock.Lock()
defer s.remotePublishersLock.Unlock()
delete(s.remotePublishers, publisher)
}