First draft of remote subscriber streams.

This commit is contained in:
Joachim Bauch 2024-04-18 12:47:29 +02:00
parent dcb7b078b1
commit 6fa606d44b
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
13 changed files with 1424 additions and 165 deletions

View file

@ -24,6 +24,7 @@ package signaling
import (
"encoding/json"
"fmt"
"net/url"
"github.com/golang-jwt/jwt/v4"
)
@ -201,6 +202,14 @@ type CommandProxyClientMessage struct {
ClientId string `json:"clientId,omitempty"`
Bitrate int `json:"bitrate,omitempty"`
MediaTypes MediaType `json:"mediatypes,omitempty"`
RemoteUrl string `json:"remoteUrl,omitempty"`
remoteUrl *url.URL
RemoteToken string `json:"remoteToken,omitempty"`
Hostname string `json:"hostname,omitempty"`
Port int `json:"port,omitempty"`
RtcpPort int `json:"rtcpPort,omitempty"`
}
func (m *CommandProxyClientMessage) CheckValid() error {
@ -218,6 +227,20 @@ func (m *CommandProxyClientMessage) CheckValid() error {
if m.StreamType == "" {
return fmt.Errorf("stream type missing")
}
if m.RemoteUrl != "" {
// TODO: Enable once subscriber deletion provides this value.
/*
if m.RemoteToken == "" {
return fmt.Errorf("remote token type missing")
}
*/
remoteUrl, err := url.Parse(m.RemoteUrl)
if err != nil {
return fmt.Errorf("invalid remote url: %w", err)
}
m.remoteUrl = remoteUrl
}
case "delete-publisher":
fallthrough
case "delete-subscriber":

View file

@ -949,9 +949,10 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s
subscriber, found := s.subscribers[getStreamId(id, streamType)]
if !found {
client := s.getClientUnlocked()
s.mu.Unlock()
var err error
subscriber, err = mcu.NewSubscriber(ctx, s, id, streamType)
subscriber, err = mcu.NewSubscriber(ctx, s, id, streamType, client)
s.mu.Lock()
if err != nil {
return nil, err

View file

@ -99,6 +99,9 @@ The running container can be configured through different environment variables:
- `CONFIG`: Optional name of configuration file to use.
- `HTTP_LISTEN`: Address of HTTP listener.
- `COUNTRY`: Optional ISO 3166 country this proxy is located at.
- `EXTERNAL_HOSTNAME`: The external hostname for remote streams. Will try to autodetect if omitted.
- `TOKEN_ID`: Id of the token to use when connecting remote streams.
- `TOKEN_KEY`: Private key for the configured token id.
- `JANUS_URL`: Url to Janus server.
- `MAX_STREAM_BITRATE`: Optional maximum bitrate for audio/video streams.
- `MAX_SCREEN_BITRATE`: Optional maximum bitrate for screensharing streams.

View file

@ -44,6 +44,16 @@ if [ ! -f "$CONFIG" ]; then
sed -i "s|#country =.*|country = $COUNTRY|" "$CONFIG"
fi
if [ -n "$EXTERNAL_HOSTNAME" ]; then
sed -i "s|#hostname =.*|hostname = $EXTERNAL_HOSTNAME|" "$CONFIG"
fi
if [ -n "$TOKEN_ID" ]; then
sed -i "s|#token_id =.*|token_id = $TOKEN_ID|" "$CONFIG"
fi
if [ -n "$TOKEN_KEY" ]; then
sed -i "s|#token_key =.*|token_key = $TOKEN_KEY|" "$CONFIG"
fi
HAS_ETCD=
if [ -n "$ETCD_ENDPOINTS" ]; then
sed -i "s|#endpoints =.*|endpoints = $ETCD_ENDPOINTS|" "$CONFIG"

View file

@ -76,7 +76,18 @@ type Mcu interface {
GetStats() interface{}
NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType StreamType, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error)
NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error)
NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error)
}
type RemotePublisherController interface {
PublisherId() string
StartPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error
}
type RemoteMcu interface {
NewRemotePublisher(ctx context.Context, listener McuListener, controller RemotePublisherController, streamType StreamType) (McuRemotePublisher, error)
NewRemoteSubscriber(ctx context.Context, listener McuListener, publisher McuRemotePublisher) (McuRemoteSubscriber, error)
}
type StreamType string
@ -116,6 +127,8 @@ type McuPublisher interface {
HasMedia(MediaType) bool
SetMedia(MediaType)
PublishRemote(ctx context.Context, hostname string, port int, rtcpPort int) error
}
type McuSubscriber interface {
@ -123,3 +136,18 @@ type McuSubscriber interface {
Publisher() string
}
type McuRemotePublisherProperties interface {
Port() int
RtcpPort() int
}
type McuRemotePublisher interface {
McuClient
McuRemotePublisherProperties
}
type McuRemoteSubscriber interface {
McuSubscriber
}

View file

@ -25,6 +25,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log"
"reflect"
@ -53,6 +54,8 @@ const (
)
var (
ErrRemoteStreamsNotSupported = errors.New("Need Janus 1.1.0 for remote streams")
streamTypeUserIds = map[StreamType]uint64{
StreamTypeVideo: videoPublisherUserId,
StreamTypeScreen: screenPublisherUserId,
@ -143,6 +146,7 @@ type mcuJanus struct {
gw *JanusGateway
session *JanusSession
handle *JanusHandle
version int
closeChan chan struct{}
@ -154,6 +158,7 @@ type mcuJanus struct {
publishers map[string]*mcuJanusPublisher
publisherCreated Notifier
publisherConnected Notifier
remotePublishers map[string]*mcuJanusRemotePublisher
reconnectTimer *time.Timer
reconnectInterval time.Duration
@ -189,7 +194,8 @@ func NewMcuJanus(url string, config *goconf.ConfigFile) (Mcu, error) {
closeChan: make(chan struct{}, 1),
clients: make(map[clientInterface]bool),
publishers: make(map[string]*mcuJanusPublisher),
publishers: make(map[string]*mcuJanusPublisher),
remotePublishers: make(map[string]*mcuJanusRemotePublisher),
reconnectInterval: initialReconnectInterval,
}
@ -288,6 +294,10 @@ func (m *mcuJanus) isMultistream() bool {
return m.version >= 1000
}
func (m *mcuJanus) hasRemotePublisher() bool {
return m.version >= 1100
}
func (m *mcuJanus) Start() error {
ctx := context.TODO()
info, err := m.gw.Info(ctx)
@ -727,17 +737,7 @@ func min(a, b int) int {
return b
}
func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, streamType StreamType, bitrate int) (*JanusHandle, uint64, uint64, int, error) {
session := m.session
if session == nil {
return nil, 0, 0, 0, ErrNotConnected
}
handle, err := session.Attach(ctx, pluginVideoRoom)
if err != nil {
return nil, 0, 0, 0, err
}
log.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id)
func (m *mcuJanus) createPublisherRoom(ctx context.Context, handle *JanusHandle, id string, streamType StreamType, bitrate int) (uint64, int, error) {
create_msg := map[string]interface{}{
"request": "create",
"description": getStreamId(id, streamType),
@ -764,7 +764,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st
if _, err2 := handle.Detach(ctx); err2 != nil {
log.Printf("Error detaching handle %d: %s", handle.Id, err2)
}
return nil, 0, 0, 0, err
return 0, 0, err
}
roomId := getPluginIntValue(create_response.PluginData, pluginVideoRoom, "room")
@ -772,10 +772,32 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st
if _, err := handle.Detach(ctx); err != nil {
log.Printf("Error detaching handle %d: %s", handle.Id, err)
}
return nil, 0, 0, 0, fmt.Errorf("No room id received: %+v", create_response)
return 0, 0, fmt.Errorf("No room id received: %+v", create_response)
}
log.Println("Created room", roomId, create_response.PluginData)
return roomId, bitrate, nil
}
func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, streamType StreamType, bitrate int) (*JanusHandle, uint64, uint64, int, error) {
session := m.session
if session == nil {
return nil, 0, 0, 0, ErrNotConnected
}
handle, err := session.Attach(ctx, pluginVideoRoom)
if err != nil {
return nil, 0, 0, 0, err
}
log.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id)
roomId, bitrate, err := m.createPublisherRoom(ctx, handle, id, streamType, bitrate)
if err != nil {
if _, err2 := handle.Detach(ctx); err2 != nil {
log.Printf("Error detaching handle %d: %s", handle.Id, err2)
}
return nil, 0, 0, 0, err
}
msg := map[string]interface{}{
"request": "join",
@ -983,6 +1005,97 @@ func (p *mcuJanusPublisher) SendMessage(ctx context.Context, message *MessageCli
}
}
func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, hostname string, port int, rtcpPort int) error {
msg := map[string]interface{}{
"request": "publish_remotely",
"room": p.roomId,
"publisher_id": streamTypeUserIds[p.streamType],
"remote_id": p.id,
"host": hostname,
"port": port,
"rtcp_port": rtcpPort,
}
response, err := p.handle.Request(ctx, msg)
if err != nil {
return err
}
errorMessage := getPluginStringValue(response.PluginData, pluginVideoRoom, "error")
errorCode := getPluginIntValue(response.PluginData, pluginVideoRoom, "error_code")
if errorMessage != "" || errorCode != 0 {
if errorMessage == "" {
errorMessage = "unknown error"
}
return fmt.Errorf("%s (%d)", errorMessage, errorCode)
}
log.Printf("Publishing %s to %s (port=%d, rtcpPort=%d)", p.id, hostname, port, rtcpPort)
return nil
}
type mcuJanusRemotePublisher struct {
mcuJanusClient
ref atomic.Int64
publisher string
port int
rtcpPort int
}
func (p *mcuJanusRemotePublisher) addRef() int64 {
return p.ref.Add(1)
}
func (p *mcuJanusRemotePublisher) release() bool {
return p.ref.Add(-1) == 0
}
func (p *mcuJanusRemotePublisher) Port() int {
return p.port
}
func (p *mcuJanusRemotePublisher) RtcpPort() int {
return p.rtcpPort
}
func (p *mcuJanusRemotePublisher) Close(ctx context.Context) {
if !p.release() {
return
}
p.mu.Lock()
if handle := p.handle; handle != nil {
response, err := p.handle.Request(ctx, map[string]interface{}{
"request": "remove_remote_publisher",
"room": p.roomId,
"id": streamTypeUserIds[p.streamType],
})
if err != nil {
log.Printf("Error removing remote publisher %d in room %d: %s", p.id, p.roomId, err)
} else {
log.Printf("Removed remote publisher: %+v", response)
}
if p.roomId != 0 {
destroy_msg := map[string]interface{}{
"request": "destroy",
"room": p.roomId,
}
if _, err := handle.Request(ctx, destroy_msg); err != nil {
log.Printf("Error destroying room %d: %s", p.roomId, err)
} else {
log.Printf("Room %d destroyed", p.roomId)
}
p.mcu.mu.Lock()
delete(p.mcu.remotePublishers, getStreamId(p.publisher, p.streamType))
p.mcu.mu.Unlock()
p.roomId = 0
}
}
p.closeClient(ctx)
p.mu.Unlock()
}
type mcuJanusSubscriber struct {
mcuJanusClient
@ -1037,7 +1150,7 @@ func (m *mcuJanus) getOrCreateSubscriberHandle(ctx context.Context, publisher st
return handle, pub, nil
}
func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) {
func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) {
if _, found := streamTypeUserIds[streamType]; !found {
return nil, fmt.Errorf("Unsupported stream type %s", streamType)
}
@ -1078,6 +1191,186 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ
return client, nil
}
type mcuJanusRemoteSubscriber struct {
mcuJanusSubscriber
remote atomic.Pointer[mcuJanusRemotePublisher]
}
func (s *mcuJanusRemoteSubscriber) Close(ctx context.Context) {
s.mcuJanusSubscriber.Close(ctx)
if remote := s.remote.Swap(nil); remote != nil {
remote.Close(context.Background())
}
}
func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller RemotePublisherController, streamType StreamType, bitrate int) (*mcuJanusRemotePublisher, error) {
m.mu.Lock()
defer m.mu.Unlock()
pub, found := m.remotePublishers[getStreamId(controller.PublisherId(), streamType)]
if found {
return pub, nil
}
session := m.session
if session == nil {
return nil, ErrNotConnected
}
handle, err := session.Attach(ctx, pluginVideoRoom)
if err != nil {
return nil, err
}
roomId, bitrate, err := m.createPublisherRoom(ctx, handle, controller.PublisherId(), streamType, bitrate)
if err != nil {
if _, err2 := handle.Detach(ctx); err2 != nil {
log.Printf("Error detaching handle %d: %s", handle.Id, err2)
}
return nil, err
}
response, err := handle.Request(ctx, map[string]interface{}{
"request": "add_remote_publisher",
"room": roomId,
"id": streamTypeUserIds[streamType],
"streams": []map[string]interface{}{
{
"mid": "0",
"mindex": 0,
"type": "audio",
"codec": "opus",
"fec": true,
},
{
"mid": "1",
"mindex": 1,
"type": "video",
"codec": "vp8",
"simulcast": true,
},
{
"mid": "2",
"mindex": 2,
"type": "data",
},
},
})
if err != nil {
if _, err2 := handle.Detach(ctx); err2 != nil {
log.Printf("Error detaching handle %d: %s", handle.Id, err2)
}
return nil, err
}
id := getPluginIntValue(response.PluginData, pluginVideoRoom, "id")
port := getPluginIntValue(response.PluginData, pluginVideoRoom, "port")
rtcp_port := getPluginIntValue(response.PluginData, pluginVideoRoom, "rtcp_port")
pub = &mcuJanusRemotePublisher{
mcuJanusClient: mcuJanusClient{
mcu: m,
id: id,
session: response.Session,
roomId: roomId,
sid: strconv.FormatUint(handle.Id, 10),
streamType: streamType,
maxBitrate: bitrate,
handle: handle,
handleId: handle.Id,
closeChan: make(chan struct{}, 1),
deferred: make(chan func(), 64),
},
publisher: controller.PublisherId(),
port: int(port),
rtcpPort: int(rtcp_port),
}
if err := controller.StartPublishing(ctx, pub); err != nil {
go pub.Close(context.Background())
return nil, err
}
m.remotePublishers[getStreamId(controller.PublisherId(), streamType)] = pub
return pub, nil
}
func (m *mcuJanus) NewRemotePublisher(ctx context.Context, listener McuListener, controller RemotePublisherController, streamType StreamType) (McuRemotePublisher, error) {
if _, found := streamTypeUserIds[streamType]; !found {
return nil, fmt.Errorf("Unsupported stream type %s", streamType)
}
if !m.hasRemotePublisher() {
return nil, ErrRemoteStreamsNotSupported
}
pub, err := m.getOrCreateRemotePublisher(ctx, controller, streamType, 0)
if err != nil {
return nil, err
}
pub.addRef()
return pub, nil
}
func (m *mcuJanus) NewRemoteSubscriber(ctx context.Context, listener McuListener, publisher McuRemotePublisher) (McuRemoteSubscriber, error) {
pub, ok := publisher.(*mcuJanusRemotePublisher)
if !ok {
return nil, errors.New("unsupported remote publisher")
}
session := m.session
if session == nil {
return nil, ErrNotConnected
}
handle, err := session.Attach(ctx, pluginVideoRoom)
if err != nil {
return nil, err
}
log.Printf("Attached subscriber to room %d of publisher %s in plugin %s in session %d as %d", pub.roomId, pub.publisher, pluginVideoRoom, session.Id, handle.Id)
client := &mcuJanusRemoteSubscriber{
mcuJanusSubscriber: mcuJanusSubscriber{
mcuJanusClient: mcuJanusClient{
mcu: m,
listener: listener,
id: m.clientId.Add(1),
roomId: pub.roomId,
sid: strconv.FormatUint(handle.Id, 10),
streamType: publisher.StreamType(),
maxBitrate: pub.MaxBitrate(),
handle: handle,
handleId: handle.Id,
closeChan: make(chan struct{}, 1),
deferred: make(chan func(), 64),
},
publisher: pub.publisher,
},
}
client.remote.Store(pub)
pub.addRef()
client.mcuJanusClient.handleEvent = client.handleEvent
client.mcuJanusClient.handleHangup = client.handleHangup
client.mcuJanusClient.handleDetached = client.handleDetached
client.mcuJanusClient.handleConnected = client.handleConnected
client.mcuJanusClient.handleSlowLink = client.handleSlowLink
client.mcuJanusClient.handleMedia = client.handleMedia
m.registerClient(client)
go client.run(handle, client.closeChan)
statsSubscribersCurrent.WithLabelValues(string(publisher.StreamType())).Inc()
statsSubscribersTotal.WithLabelValues(string(publisher.StreamType())).Inc()
return client, nil
}
func (p *mcuJanusSubscriber) Publisher() string {
return p.publisher
}

View file

@ -217,13 +217,18 @@ func (p *mcuProxyPublisher) ProcessEvent(msg *EventProxyServerMessage) {
}
}
func (p *mcuProxyPublisher) PublishRemote(ctx context.Context, hostname string, port int, rtcpPort int) error {
return errors.New("remote publishing not supported for proxy publishers")
}
type mcuProxySubscriber struct {
mcuProxyPubSubCommon
publisherId string
publisherId string
publisherConn *mcuProxyConnection
}
func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType, maxBitrate int, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxySubscriber {
func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType, maxBitrate int, proxyId string, conn *mcuProxyConnection, listener McuListener, publisherConn *mcuProxyConnection) *mcuProxySubscriber {
return &mcuProxySubscriber{
mcuProxyPubSubCommon: mcuProxyPubSubCommon{
sid: sid,
@ -234,7 +239,8 @@ func newMcuProxySubscriber(publisherId string, sid string, streamType StreamType
listener: listener,
},
publisherId: publisherId,
publisherId: publisherId,
publisherConn: publisherConn,
}
}
@ -257,16 +263,32 @@ func (s *mcuProxySubscriber) Close(ctx context.Context) {
ClientId: s.proxyId,
},
}
if s.publisherConn != nil {
msg.Command.RemoteUrl = s.publisherConn.rawUrl
// TODO: Add remote token for this subscriber.
}
if response, err := s.conn.performSyncRequest(ctx, msg); err != nil {
log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, err)
if s.publisherConn != nil {
log.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", s.proxyId, s.conn, s.publisherConn, err)
} else {
log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, err)
}
return
} else if response.Type == "error" {
log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, response.Error)
if s.publisherConn != nil {
log.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", s.proxyId, s.conn, s.publisherConn, response.Error)
} else {
log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, response.Error)
}
return
}
log.Printf("Delete subscriber %s at %s", s.proxyId, s.conn)
if s.publisherConn != nil {
log.Printf("Delete remote subscriber %s at %s (forwarded to %s)", s.proxyId, s.conn, s.publisherConn)
} else {
log.Printf("Delete subscriber %s at %s", s.proxyId, s.conn)
}
}
func (s *mcuProxySubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) {
@ -371,6 +393,54 @@ func (c *mcuProxyConnection) String() string {
return c.rawUrl
}
func (c *mcuProxyConnection) IsSameCountry(initiator McuInitiator) bool {
if initiator == nil {
return true
}
initiatorCountry := initiator.Country()
if initiatorCountry == "" {
return true
}
connCountry := c.Country()
if connCountry == "" {
return true
}
return initiatorCountry == connCountry
}
func (c *mcuProxyConnection) IsSameContinent(initiator McuInitiator) bool {
if initiator == nil {
return true
}
initiatorCountry := initiator.Country()
if initiatorCountry == "" {
return true
}
connCountry := c.Country()
if connCountry == "" {
return true
}
initiatorContinents, found := ContinentMap[initiatorCountry]
if found {
m := c.proxy.getContinentsMap()
// Map continents to other continents (e.g. use Europe for Africa).
for _, continent := range initiatorContinents {
if toAdd, found := m[continent]; found {
initiatorContinents = append(initiatorContinents, toAdd...)
}
}
}
connContinents := ContinentMap[connCountry]
return ContinentsOverlap(initiatorContinents, connContinents)
}
type mcuProxyConnectionStats struct {
Url string `json:"url"`
IP net.IP `json:"ip,omitempty"`
@ -978,14 +1048,7 @@ func (c *mcuProxyConnection) sendHello() error {
if sessionId := c.SessionId(); sessionId != "" {
msg.Hello.ResumeId = sessionId
} else {
claims := &TokenClaims{
jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: c.proxy.tokenId,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenString, err := token.SignedString(c.proxy.tokenKey)
tokenString, err := c.proxy.createToken("")
if err != nil {
return err
}
@ -1106,7 +1169,48 @@ func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuList
proxyId := response.Command.Id
log.Printf("Created %s subscriber %s on %s for %s", streamType, proxyId, c, publisherSessionId)
subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener)
subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener, nil)
c.subscribersLock.Lock()
c.subscribers[proxyId] = subscriber
c.subscribersLock.Unlock()
statsSubscribersCurrent.WithLabelValues(string(streamType)).Inc()
statsSubscribersTotal.WithLabelValues(string(streamType)).Inc()
return subscriber, nil
}
func (c *mcuProxyConnection) newRemoteSubscriber(ctx context.Context, listener McuListener, publisherId string, publisherSessionId string, streamType StreamType, publisherConn *mcuProxyConnection) (McuSubscriber, error) {
if c == publisherConn {
return c.newSubscriber(ctx, listener, publisherId, publisherSessionId, streamType)
}
remoteToken, err := c.proxy.createToken(publisherId)
if err != nil {
return nil, err
}
msg := &ProxyClientMessage{
Type: "command",
Command: &CommandProxyClientMessage{
Type: "create-subscriber",
StreamType: streamType,
PublisherId: publisherId,
RemoteUrl: publisherConn.rawUrl,
RemoteToken: remoteToken,
},
}
response, err := c.performSyncRequest(ctx, msg)
if err != nil {
// TODO: Cancel request
return nil, err
} else if response.Type == "error" {
return nil, fmt.Errorf("Error creating remote %s subscriber for %s on %s (forwarded to %s): %+v", streamType, publisherSessionId, c, publisherConn, response.Error)
}
proxyId := response.Command.Id
log.Printf("Created remote %s subscriber %s on %s for %s (forwarded to %s)", streamType, proxyId, c, publisherSessionId, publisherConn)
subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener, publisherConn)
c.subscribersLock.Lock()
c.subscribers[proxyId] = subscriber
c.subscribersLock.Unlock()
@ -1289,6 +1393,23 @@ func (m *mcuProxy) Stop() {
m.config.Stop()
}
func (m *mcuProxy) createToken(subject string) (string, error) {
claims := &TokenClaims{
jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: m.tokenId,
Subject: subject,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenString, err := token.SignedString(m.tokenKey)
if err != nil {
return "", err
}
return tokenString, nil
}
func (m *mcuProxy) hasConnections() bool {
m.connectionsMu.RLock()
defer m.connectionsMu.RUnlock()
@ -1681,7 +1802,14 @@ func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher str
}
}
func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) {
type proxyPublisherInfo struct {
id string
conn *mcuProxyConnection
err error
}
func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) {
var publisherInfo *proxyPublisherInfo
if conn := m.getPublisherConnection(publisher, streamType); conn != nil {
// Fast common path: publisher is available locally.
conn.publishersLock.Lock()
@ -1691,113 +1819,159 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ
return nil, fmt.Errorf("Unknown publisher %s", publisher)
}
return conn.newSubscriber(ctx, listener, id, publisher, streamType)
}
log.Printf("No %s publisher %s found yet, deferring", streamType, publisher)
ch := make(chan McuSubscriber)
getctx, cancel := context.WithCancel(ctx)
defer cancel()
// Wait for publisher to be created locally.
go func() {
if conn := m.waitForPublisherConnection(getctx, publisher, streamType); conn != nil {
cancel() // Cancel pending RPC calls.
conn.publishersLock.Lock()
id, found := conn.publisherIds[getStreamId(publisher, streamType)]
conn.publishersLock.Unlock()
if !found {
log.Printf("Unknown id for local %s publisher %s", streamType, publisher)
return
}
subscriber, err := conn.newSubscriber(ctx, listener, id, publisher, streamType)
if subscriber != nil {
ch <- subscriber
} else if err != nil {
log.Printf("Error creating local subscriber for %s publisher %s: %s", streamType, publisher, err)
}
publisherInfo = &proxyPublisherInfo{
id: id,
conn: conn,
}
}()
} else {
log.Printf("No %s publisher %s found yet, deferring", streamType, publisher)
ch := make(chan *proxyPublisherInfo, 1)
getctx, cancel := context.WithCancel(ctx)
defer cancel()
// Wait for publisher to be created on one of the other servers in the cluster.
if clients := m.rpcClients.GetClients(); len(clients) > 0 {
for _, client := range clients {
go func(client *GrpcClient) {
id, url, ip, err := client.GetPublisherId(getctx, publisher, streamType)
if errors.Is(err, context.Canceled) {
return
} else if err != nil {
log.Printf("Error getting %s publisher id %s from %s: %s", streamType, publisher, client.Target(), err)
return
} else if id == "" {
// Publisher not found on other server
return
}
var wg sync.WaitGroup
// Wait for publisher to be created locally.
wg.Add(1)
go func() {
defer wg.Done()
if conn := m.waitForPublisherConnection(getctx, publisher, streamType); conn != nil {
cancel() // Cancel pending RPC calls.
log.Printf("Found publisher id %s through %s on proxy %s", id, client.Target(), url)
m.connectionsMu.RLock()
connections := m.connections
m.connectionsMu.RUnlock()
var publisherConn *mcuProxyConnection
for _, conn := range connections {
if conn.rawUrl != url || !ip.Equal(conn.ip) {
continue
conn.publishersLock.Lock()
id, found := conn.publisherIds[getStreamId(publisher, streamType)]
conn.publishersLock.Unlock()
if !found {
ch <- &proxyPublisherInfo{
err: fmt.Errorf("Unknown id for local %s publisher %s", streamType, publisher),
}
// Simple case, signaling server has a connection to the same endpoint
publisherConn = conn
break
}
if publisherConn == nil {
publisherConn, err = newMcuProxyConnection(m, url, ip)
if err != nil {
log.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err)
return
}
publisherConn.setTemporary()
publisherConn.start()
if err := publisherConn.waitUntilConnected(ctx); err != nil {
log.Printf("Could not establish new connection to %s: %s", publisherConn, err)
publisherConn.closeIfEmpty()
return
}
m.connectionsMu.Lock()
m.connections = append(m.connections, publisherConn)
conns, found := m.connectionsMap[url]
if found {
conns = append(conns, publisherConn)
} else {
conns = []*mcuProxyConnection{publisherConn}
}
m.connectionsMap[url] = conns
m.connectionsMu.Unlock()
}
subscriber, err := publisherConn.newSubscriber(ctx, listener, id, publisher, streamType)
if err != nil {
if publisherConn.IsTemporary() {
publisherConn.closeIfEmpty()
}
log.Printf("Could not create subscriber for %s publisher %s: %s", streamType, publisher, err)
return
}
ch <- subscriber
}(client)
ch <- &proxyPublisherInfo{
id: id,
conn: conn,
}
}
}()
// Wait for publisher to be created on one of the other servers in the cluster.
if clients := m.rpcClients.GetClients(); len(clients) > 0 {
for _, client := range clients {
wg.Add(1)
go func(client *GrpcClient) {
defer wg.Done()
id, url, ip, err := client.GetPublisherId(getctx, publisher, streamType)
if errors.Is(err, context.Canceled) {
return
} else if err != nil {
log.Printf("Error getting %s publisher id %s from %s: %s", streamType, publisher, client.Target(), err)
return
} else if id == "" {
// Publisher not found on other server
return
}
cancel() // Cancel pending RPC calls.
log.Printf("Found publisher id %s through %s on proxy %s", id, client.Target(), url)
m.connectionsMu.RLock()
connections := m.connections
m.connectionsMu.RUnlock()
var publisherConn *mcuProxyConnection
for _, conn := range connections {
if conn.rawUrl != url || !ip.Equal(conn.ip) {
continue
}
// Simple case, signaling server has a connection to the same endpoint
publisherConn = conn
break
}
if publisherConn == nil {
publisherConn, err = newMcuProxyConnection(m, url, ip)
if err != nil {
log.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err)
return
}
publisherConn.setTemporary()
publisherConn.start()
if err := publisherConn.waitUntilConnected(ctx); err != nil {
log.Printf("Could not establish new connection to %s: %s", publisherConn, err)
publisherConn.closeIfEmpty()
return
}
m.connectionsMu.Lock()
m.connections = append(m.connections, publisherConn)
conns, found := m.connectionsMap[url]
if found {
conns = append(conns, publisherConn)
} else {
conns = []*mcuProxyConnection{publisherConn}
}
m.connectionsMap[url] = conns
m.connectionsMu.Unlock()
}
ch <- &proxyPublisherInfo{
id: id,
conn: publisherConn,
}
}(client)
}
}
wg.Wait()
select {
case ch <- &proxyPublisherInfo{
err: fmt.Errorf("No %s publisher %s found", streamType, publisher),
}:
default:
}
select {
case info := <-ch:
publisherInfo = info
case <-ctx.Done():
return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher)
}
}
select {
case subscriber := <-ch:
return subscriber, nil
case <-ctx.Done():
return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher)
if publisherInfo.err != nil {
return nil, publisherInfo.err
}
if !publisherInfo.conn.IsSameCountry(initiator) {
connections := m.getSortedConnections(initiator)
if len(connections) > 0 && !connections[0].IsSameCountry(publisherInfo.conn) {
// Connect to remote publisher through "closer" gateway.
for _, conn := range connections {
if conn.IsShutdownScheduled() || conn.IsTemporary() || conn == publisherInfo.conn {
continue
}
subscriber, err := conn.newRemoteSubscriber(ctx, listener, publisherInfo.id, publisher, streamType, publisherInfo.conn)
if err != nil {
log.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, conn, err)
continue
}
return subscriber, nil
}
}
}
subscriber, err := publisherInfo.conn.newSubscriber(ctx, listener, publisherInfo.id, publisher, streamType)
if err != nil {
if publisherInfo.conn.IsTemporary() {
publisherInfo.conn.closeIfEmpty()
}
log.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, publisherInfo.conn, err)
return nil, err
}
return subscriber, nil
}

View file

@ -193,12 +193,14 @@ type testProxyServerSubscriber struct {
id string
sid string
pub *testProxyServerPublisher
remoteUrl string
}
type testProxyServerClient struct {
t *testing.T
server *testProxyServerHandler
server *TestProxyServerHandler
ws *websocket.Conn
processMessage proxyServerClientHandler
@ -284,7 +286,29 @@ func (c *testProxyServerClient) processCommandMessage(msg *ProxyClientMessage) (
case "create-subscriber":
c.mu.Lock()
defer c.mu.Unlock()
pub, found := c.publishers[msg.Command.PublisherId]
var found bool
var pub *testProxyServerPublisher
if msg.Command.RemoteUrl != "" {
for _, server := range c.server.servers {
if server.URL != msg.Command.RemoteUrl {
continue
}
server.mu.Lock()
for _, client := range server.clients {
client.mu.Lock()
pub, found = client.publishers[msg.Command.PublisherId]
client.mu.Unlock()
if found {
break
}
}
server.mu.Unlock()
}
} else {
pub, found = c.publishers[msg.Command.PublisherId]
}
if !found {
response = msg.NewWrappedErrorServerMessage(fmt.Errorf("publisher %s not found", msg.Command.PublisherId))
} else {
@ -292,6 +316,8 @@ func (c *testProxyServerClient) processCommandMessage(msg *ProxyClientMessage) (
id: newRandomString(32),
sid: newRandomString(8),
pub: pub,
remoteUrl: msg.Command.RemoteUrl,
}
response = &ProxyServerMessage{
Id: msg.Id,
@ -311,6 +337,11 @@ func (c *testProxyServerClient) processCommandMessage(msg *ProxyClientMessage) (
if !found {
response = msg.NewWrappedErrorServerMessage(fmt.Errorf("subscriber %s not found", msg.Command.ClientId))
} else {
if msg.Command.RemoteUrl != sub.remoteUrl {
response = msg.NewWrappedErrorServerMessage(fmt.Errorf("remote subscriber %s not found", msg.Command.ClientId))
return response, nil
}
delete(c.subscribers, sub.id)
response = &ProxyServerMessage{
Id: msg.Id,
@ -405,9 +436,12 @@ func (c *testProxyServerClient) run() {
}
}
type testProxyServerHandler struct {
type TestProxyServerHandler struct {
t *testing.T
URL string
server *httptest.Server
servers []*TestProxyServerHandler
upgrader *websocket.Upgrader
country string
@ -416,7 +450,7 @@ type testProxyServerHandler struct {
clients map[string]*testProxyServerClient
}
func (h *testProxyServerHandler) updateLoad(delta int64) {
func (h *TestProxyServerHandler) updateLoad(delta int64) {
if delta == 0 {
return
}
@ -438,7 +472,7 @@ func (h *testProxyServerHandler) updateLoad(delta int64) {
}
}
func (h *testProxyServerHandler) sendLoad(c *testProxyServerClient) {
func (h *TestProxyServerHandler) sendLoad(c *testProxyServerClient) {
c.sendMessage(&ProxyServerMessage{
Type: "event",
Event: &EventProxyServerMessage{
@ -448,13 +482,13 @@ func (h *testProxyServerHandler) sendLoad(c *testProxyServerClient) {
})
}
func (h *testProxyServerHandler) removeClient(client *testProxyServerClient) {
func (h *TestProxyServerHandler) removeClient(client *testProxyServerClient) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.clients, client.sessionId)
}
func (h *testProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *TestProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ws, err := h.upgrader.Upgrade(w, r, nil)
if err != nil {
h.t.Error(err)
@ -480,17 +514,19 @@ func (h *testProxyServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques
}(client)
}
func NewProxyServerForTest(t *testing.T, country string) *httptest.Server {
func NewProxyServerForTest(t *testing.T, country string) *TestProxyServerHandler {
t.Helper()
upgrader := websocket.Upgrader{}
proxyHandler := &testProxyServerHandler{
proxyHandler := &TestProxyServerHandler{
t: t,
upgrader: &upgrader,
country: country,
clients: make(map[string]*testProxyServerClient),
}
server := httptest.NewServer(proxyHandler)
proxyHandler.server = server
proxyHandler.URL = server.URL
t.Cleanup(func() {
server.Close()
proxyHandler.mu.Lock()
@ -500,10 +536,10 @@ func NewProxyServerForTest(t *testing.T, country string) *httptest.Server {
}
})
return server
return proxyHandler
}
func newMcuProxyForTestWithServers(t *testing.T, servers []*httptest.Server) *mcuProxy {
func newMcuProxyForTestWithServers(t *testing.T, servers []*TestProxyServerHandler) *mcuProxy {
etcd, etcdClient := NewEtcdClientForTest(t)
grpcClients, dnsMonitor := NewGrpcClientsWithEtcdForTest(t, etcd)
@ -522,6 +558,7 @@ func newMcuProxyForTestWithServers(t *testing.T, servers []*httptest.Server) *mc
var urls []string
waitingMap := make(map[string]bool)
for _, s := range servers {
s.servers = servers
urls = append(urls, s.URL)
waitingMap[s.URL] = true
}
@ -576,7 +613,7 @@ func newMcuProxyForTest(t *testing.T) *mcuProxy {
t.Helper()
server := NewProxyServerForTest(t, "DE")
return newMcuProxyForTestWithServers(t, []*httptest.Server{server})
return newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{server})
}
func Test_ProxyPublisherSubscriber(t *testing.T) {
@ -606,7 +643,10 @@ func Test_ProxyPublisherSubscriber(t *testing.T) {
subListener := &MockMcuListener{
publicId: "subscriber-public",
}
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo)
subInitiator := &MockMcuInitiator{
country: "DE",
}
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Fatal(err)
}
@ -634,10 +674,13 @@ func Test_ProxyWaitForPublisher(t *testing.T) {
subListener := &MockMcuListener{
publicId: "subscriber-public",
}
subInitiator := &MockMcuInitiator{
country: "DE",
}
done := make(chan struct{})
go func() {
defer close(done)
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo)
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Error(err)
return
@ -667,7 +710,7 @@ func Test_ProxyPublisherLoad(t *testing.T) {
t.Parallel()
server1 := NewProxyServerForTest(t, "DE")
server2 := NewProxyServerForTest(t, "DE")
mcu := newMcuProxyForTestWithServers(t, []*httptest.Server{
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
server1,
server2,
})
@ -719,7 +762,7 @@ func Test_ProxyPublisherCountry(t *testing.T) {
t.Parallel()
serverDE := NewProxyServerForTest(t, "DE")
serverUS := NewProxyServerForTest(t, "US")
mcu := newMcuProxyForTestWithServers(t, []*httptest.Server{
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
serverDE,
serverUS,
})
@ -771,7 +814,7 @@ func Test_ProxyPublisherContinent(t *testing.T) {
t.Parallel()
serverDE := NewProxyServerForTest(t, "DE")
serverUS := NewProxyServerForTest(t, "US")
mcu := newMcuProxyForTestWithServers(t, []*httptest.Server{
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
serverDE,
serverUS,
})
@ -817,3 +860,53 @@ func Test_ProxyPublisherContinent(t *testing.T) {
t.Errorf("expected server %s, go %s", serverDE.URL, pubFR.(*mcuProxyPublisher).conn.rawUrl)
}
}
func Test_ProxySubscriberCountry(t *testing.T) {
CatchLogForTest(t)
t.Parallel()
serverDE := NewProxyServerForTest(t, "DE")
serverUS := NewProxyServerForTest(t, "US")
mcu := newMcuProxyForTestWithServers(t, []*TestProxyServerHandler{
serverDE,
serverUS,
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
pubId := "the-publisher"
pubSid := "1234567890"
pubListener := &MockMcuListener{
publicId: pubId + "-public",
}
pubInitiator := &MockMcuInitiator{
country: "DE",
}
pub, err := mcu.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, 0, MediaTypeVideo|MediaTypeAudio, pubInitiator)
if err != nil {
t.Fatal(err)
}
defer pub.Close(context.Background())
if pub.(*mcuProxyPublisher).conn.rawUrl != serverDE.URL {
t.Errorf("expected server %s, go %s", serverDE.URL, pub.(*mcuProxyPublisher).conn.rawUrl)
}
subListener := &MockMcuListener{
publicId: "subscriber-public",
}
subInitiator := &MockMcuInitiator{
country: "US",
}
sub, err := mcu.NewSubscriber(ctx, subListener, pubId, StreamTypeVideo, subInitiator)
if err != nil {
t.Fatal(err)
}
defer sub.Close(context.Background())
if sub.(*mcuProxySubscriber).conn.rawUrl != serverUS.URL {
t.Errorf("expected server %s, go %s", serverUS.URL, sub.(*mcuProxySubscriber).conn.rawUrl)
}
}

View file

@ -23,6 +23,7 @@ package signaling
import (
"context"
"errors"
"fmt"
"log"
"sync"
@ -117,7 +118,7 @@ func (m *TestMCU) GetPublisher(id string) *TestMCUPublisher {
return m.publishers[id]
}
func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType) (McuSubscriber, error) {
func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType StreamType, initiator McuInitiator) (McuSubscriber, error) {
m.mu.Lock()
defer m.mu.Unlock()
@ -222,6 +223,10 @@ func (p *TestMCUPublisher) SendMessage(ctx context.Context, message *MessageClie
}()
}
func (p *TestMCUPublisher) PublishRemote(ctx context.Context, hostname string, port int, rtcpPort int) error {
return errors.New("remote publishing not supported")
}
type TestMCUSubscriber struct {
TestMCUClient

View file

@ -20,6 +20,17 @@
# - etcd: Token information are retrieved from an etcd cluster (see below).
tokentype = static
# The external hostname for remote streams. Leaving this empty will autodetect
# and use the first public IP found on the available network interfaces.
#hostname =
# The token id to use when connecting remote stream.
#token_id = server1
# The private key for the configured token id to use when connecting remote
# streams.
#token_key = privkey.pem
[tokens]
# For token type "static": Mapping of <tokenid> = <publickey> of signaling
# servers allowed to connect.

340
proxy/proxy_remote.go Normal file
View file

@ -0,0 +1,340 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2024 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package main
import (
"context"
"crypto/rsa"
"crypto/tls"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"net/url"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/gorilla/websocket"
signaling "github.com/strukturag/nextcloud-spreed-signaling"
)
var (
ErrNotConnected = errors.New("not connected")
)
type RemoteConnection struct {
mu sync.Mutex
url *url.URL
conn *websocket.Conn
tokenId string
tokenKey *rsa.PrivateKey
msgId atomic.Int64
helloMsgId string
sessionId string
messageCallbacks map[string]chan *signaling.ProxyServerMessage
}
func NewRemoteConnection(proxyUrl string, tokenId string, tokenKey *rsa.PrivateKey) (*RemoteConnection, error) {
u, err := url.Parse(proxyUrl)
if err != nil {
return nil, err
}
result := &RemoteConnection{
url: u,
tokenId: tokenId,
tokenKey: tokenKey,
messageCallbacks: make(map[string]chan *signaling.ProxyServerMessage),
}
return result, nil
}
func (c *RemoteConnection) String() string {
return c.url.String()
}
func (c *RemoteConnection) Connect(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
return nil
}
u, err := c.url.Parse("proxy")
if err != nil {
return err
}
if u.Scheme == "http" {
u.Scheme = "ws"
} else if u.Scheme == "https" {
u.Scheme = "wss"
}
dialer := websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{
// TODO: Make this configurable.
InsecureSkipVerify: true,
},
}
conn, _, err := dialer.DialContext(ctx, u.String(), nil)
if err != nil {
return err
}
c.conn = conn
go c.readPump()
return c.sendHello()
}
func (c *RemoteConnection) sendHello() error {
c.helloMsgId = strconv.FormatInt(c.msgId.Add(1), 10)
msg := &signaling.ProxyClientMessage{
Id: c.helloMsgId,
Type: "hello",
Hello: &signaling.HelloProxyClientMessage{
Version: "1.0",
},
}
if sessionId := c.sessionId; sessionId != "" {
msg.Hello.ResumeId = sessionId
} else {
tokenString, err := c.createToken("")
if err != nil {
return err
}
msg.Hello.Token = tokenString
}
return c.sendMessageLocked(msg)
}
func (c *RemoteConnection) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
return nil
}
err1 := c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{})
err2 := c.conn.Close()
c.conn = nil
if err1 != nil {
return err1
}
return err2
}
func (c *RemoteConnection) createToken(subject string) (string, error) {
claims := &signaling.TokenClaims{
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: c.tokenId,
Subject: subject,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenString, err := token.SignedString(c.tokenKey)
if err != nil {
return "", err
}
return tokenString, nil
}
func (c *RemoteConnection) SendMessage(msg *signaling.ProxyClientMessage) error {
c.mu.Lock()
defer c.mu.Unlock()
return c.sendMessageLocked(msg)
}
func (c *RemoteConnection) sendMessageLocked(msg *signaling.ProxyClientMessage) error {
if c.conn == nil {
return ErrNotConnected
}
return c.conn.WriteJSON(msg)
}
func (c *RemoteConnection) readPump() {
for {
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
return
}
msgType, reader, err := conn.NextReader()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
log.Printf("error reading: %s", err)
}
c.mu.Lock()
c.conn = nil
c.mu.Unlock()
return
}
body, err := io.ReadAll(reader)
if err != nil {
log.Printf("error reading message: %s", err)
continue
}
if msgType != websocket.TextMessage {
log.Printf("unexpected message type %q (%s)", msgType, string(body))
continue
}
var msg signaling.ProxyServerMessage
if err := json.Unmarshal(body, &msg); err != nil {
log.Printf("could not decode message %s: %s", string(body), err)
continue
}
c.mu.Lock()
helloMsgId := c.helloMsgId
c.mu.Unlock()
if helloMsgId != "" && msg.Id == helloMsgId {
c.processHello(&msg)
} else {
c.processMessage(&msg)
}
}
}
func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) {
c.mu.Lock()
defer c.mu.Unlock()
c.helloMsgId = ""
switch msg.Type {
case "error":
if msg.Error.Code == "no_such_session" {
log.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c)
c.sessionId = ""
if err := c.sendHello(); err != nil {
log.Printf("Could not send hello request to %s: %s", c, err)
// TODO: c.scheduleReconnect()
}
return
}
log.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error)
// TODO: c.scheduleReconnect()
case "hello":
resumed := c.sessionId == msg.Hello.SessionId
c.sessionId = msg.Hello.SessionId
country := ""
if msg.Hello.Server != nil {
if country = msg.Hello.Server.Country; country != "" && !signaling.IsValidCountry(country) {
log.Printf("Proxy %s sent invalid country %s in hello response", c, country)
country = ""
}
}
if resumed {
log.Printf("Resumed session %s on %s", c.sessionId, c)
} else if country != "" {
log.Printf("Received session %s from %s (in %s)", c.sessionId, c, country)
} else {
log.Printf("Received session %s from %s", c.sessionId, c)
}
default:
log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c)
// TODO: c.scheduleReconnect()
}
}
func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) {
if msg.Id != "" {
c.mu.Lock()
ch, found := c.messageCallbacks[msg.Id]
if found {
delete(c.messageCallbacks, msg.Id)
c.mu.Unlock()
ch <- msg
return
}
c.mu.Unlock()
}
switch msg.Type {
case "event":
c.processEvent(msg)
default:
log.Printf("Received unsupported message %+v from %s", msg, c)
}
}
func (c *RemoteConnection) processEvent(msg *signaling.ProxyServerMessage) {
switch msg.Event.Type {
case "update-load":
default:
log.Printf("Received unsupported event %+v from %s", msg, c)
}
}
func (c *RemoteConnection) RequestMessage(ctx context.Context, msg *signaling.ProxyClientMessage) (*signaling.ProxyServerMessage, error) {
msg.Id = strconv.FormatInt(c.msgId.Add(1), 10)
c.mu.Lock()
defer c.mu.Unlock()
if err := c.sendMessageLocked(msg); err != nil {
return nil, err
}
ch := make(chan *signaling.ProxyServerMessage, 1)
c.messageCallbacks[msg.Id] = ch
c.mu.Unlock()
defer func() {
c.mu.Lock()
delete(c.messageCallbacks, msg.Id)
}()
select {
case <-ctx.Done():
// TODO: Cancel request.
return nil, ctx.Err()
case response := <-ch:
if response.Type == "error" {
return nil, response.Error
}
return response, nil
}
}

View file

@ -24,7 +24,9 @@ package main
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"io"
"log"
@ -63,6 +65,8 @@ const (
// Maximum age a token may have to prevent reuse of old tokens.
maxTokenAge = 5 * time.Minute
remotePublisherTimeout = 5 * time.Second
)
type ContextKey string
@ -70,22 +74,24 @@ type ContextKey string
var (
ContextKeySession = ContextKey("session")
TimeoutCreatingPublisher = signaling.NewError("timeout", "Timeout creating publisher.")
TimeoutCreatingSubscriber = signaling.NewError("timeout", "Timeout creating subscriber.")
TokenAuthFailed = signaling.NewError("auth_failed", "The token could not be authenticated.")
TokenExpired = signaling.NewError("token_expired", "The token is expired.")
TokenNotValidYet = signaling.NewError("token_not_valid_yet", "The token is not valid yet.")
UnknownClient = signaling.NewError("unknown_client", "Unknown client id given.")
UnsupportedCommand = signaling.NewError("bad_request", "Unsupported command received.")
UnsupportedMessage = signaling.NewError("bad_request", "Unsupported message received.")
UnsupportedPayload = signaling.NewError("unsupported_payload", "Unsupported payload type.")
ShutdownScheduled = signaling.NewError("shutdown_scheduled", "The server is scheduled to shutdown.")
TimeoutCreatingPublisher = signaling.NewError("timeout", "Timeout creating publisher.")
TimeoutCreatingSubscriber = signaling.NewError("timeout", "Timeout creating subscriber.")
TokenAuthFailed = signaling.NewError("auth_failed", "The token could not be authenticated.")
TokenExpired = signaling.NewError("token_expired", "The token is expired.")
TokenNotValidYet = signaling.NewError("token_not_valid_yet", "The token is not valid yet.")
UnknownClient = signaling.NewError("unknown_client", "Unknown client id given.")
UnsupportedCommand = signaling.NewError("bad_request", "Unsupported command received.")
UnsupportedMessage = signaling.NewError("bad_request", "Unsupported message received.")
UnsupportedPayload = signaling.NewError("unsupported_payload", "Unsupported payload type.")
ShutdownScheduled = signaling.NewError("shutdown_scheduled", "The server is scheduled to shutdown.")
RemoteSubscribersNotSupported = signaling.NewError("unsupported_subscriber", "Remote subscribers are not supported.")
)
type ProxyServer struct {
version string
country string
welcomeMessage string
config *goconf.ConfigFile
url string
mcu signaling.Mcu
@ -108,6 +114,47 @@ type ProxyServer struct {
clients map[string]signaling.McuClient
clientIds map[string]string
clientsLock sync.RWMutex
tokenId string
tokenKey *rsa.PrivateKey
remoteHostname string
remoteConnections map[string]*RemoteConnection
remoteConnectionsLock sync.Mutex
}
func IsPublicIP(IP net.IP) bool {
if IP.IsLoopback() || IP.IsLinkLocalMulticast() || IP.IsLinkLocalUnicast() {
return false
}
if ip4 := IP.To4(); ip4 != nil {
switch {
case ip4[0] == 10:
return false
case ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31:
return false
case ip4[0] == 192 && ip4[1] == 168:
return false
default:
return true
}
}
return false
}
func GetLocalIP() (string, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "", err
}
for _, address := range addrs {
if ipnet, ok := address.(*net.IPNet); ok && IsPublicIP(ipnet.IP) {
if ipnet.IP.To4() != nil {
return ipnet.IP.String(), nil
}
}
}
return "", nil
}
func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*ProxyServer, error) {
@ -173,10 +220,45 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*
return nil, err
}
tokenId, _ := config.GetString("app", "token_id")
var tokenKey *rsa.PrivateKey
var remoteHostname string
if tokenId != "" {
tokenKeyFilename, _ := config.GetString("app", "token_key")
if tokenKeyFilename == "" {
return nil, fmt.Errorf("No token key configured")
}
tokenKeyData, err := os.ReadFile(tokenKeyFilename)
if err != nil {
return nil, fmt.Errorf("Could not read private key from %s: %s", tokenKeyFilename, err)
}
tokenKey, err = jwt.ParseRSAPrivateKeyFromPEM(tokenKeyData)
if err != nil {
return nil, fmt.Errorf("Could not parse private key from %s: %s", tokenKeyFilename, err)
}
log.Printf("Using \"%s\" as token id for remote streams", tokenId)
remoteHostname, _ = config.GetString("app", "hostname")
if remoteHostname == "" {
remoteHostname, err = GetLocalIP()
if err != nil {
return nil, fmt.Errorf("could not get local ip: %w", err)
}
}
if remoteHostname == "" {
log.Printf("WARNING: Could not determine hostname for remote streams, will be disabled. Please configure manually.")
} else {
log.Printf("Using \"%s\" as hostname for remote streams", remoteHostname)
}
} else {
log.Printf("No token id configured, remote streams will be disabled")
}
result := &ProxyServer{
version: version,
country: country,
welcomeMessage: string(welcomeMessage) + "\n",
config: config,
shutdownChannel: make(chan struct{}),
@ -193,6 +275,11 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*
clients: make(map[string]signaling.McuClient),
clientIds: make(map[string]string),
tokenId: tokenId,
tokenKey: tokenKey,
remoteHostname: remoteHostname,
remoteConnections: make(map[string]*RemoteConnection),
}
result.upgrader.CheckOrigin = result.checkOrigin
@ -613,6 +700,40 @@ func (i *emptyInitiator) Country() string {
return ""
}
type proxyRemotePublisher struct {
proxy *ProxyServer
remoteUrl string
publisherId string
}
func (p *proxyRemotePublisher) PublisherId() string {
return p.publisherId
}
func (p *proxyRemotePublisher) StartPublishing(ctx context.Context, publisher signaling.McuRemotePublisherProperties) error {
var conn *RemoteConnection
conn, err := p.proxy.getRemoteConnection(ctx, p.remoteUrl)
if err != nil {
return err
}
if _, err := conn.RequestMessage(ctx, &signaling.ProxyClientMessage{
Type: "command",
Command: &signaling.CommandProxyClientMessage{
Type: "publish-remote",
ClientId: p.publisherId,
Hostname: p.proxy.remoteHostname,
Port: publisher.Port(),
RtcpPort: publisher.RtcpPort(),
},
}); err != nil {
return err
}
return nil
}
func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) {
cmd := message.Command
@ -655,18 +776,74 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
case "create-subscriber":
id := uuid.New().String()
publisherId := cmd.PublisherId
subscriber, err := s.mcu.NewSubscriber(ctx, session, publisherId, cmd.StreamType)
if err == context.DeadlineExceeded {
log.Printf("Timeout while creating %s subscriber on %s for %s", cmd.StreamType, publisherId, session.PublicId())
session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingSubscriber))
return
} else if err != nil {
var subscriber signaling.McuSubscriber
var err error
handleCreateError := func(err error) {
if err == context.DeadlineExceeded {
log.Printf("Timeout while creating %s subscriber on %s for %s", cmd.StreamType, publisherId, session.PublicId())
session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingSubscriber))
return
} else if errors.Is(err, signaling.ErrRemoteStreamsNotSupported) {
session.sendMessage(message.NewErrorServerMessage(RemoteSubscribersNotSupported))
return
}
log.Printf("Error while creating %s subscriber on %s for %s: %s", cmd.StreamType, publisherId, session.PublicId(), err)
session.sendMessage(message.NewWrappedErrorServerMessage(err))
return
}
log.Printf("Created %s subscriber %s as %s for %s", cmd.StreamType, subscriber.Id(), id, session.PublicId())
if cmd.RemoteUrl != "" {
if s.tokenId == "" || s.tokenKey == nil || s.remoteHostname == "" {
session.sendMessage(message.NewErrorServerMessage(RemoteSubscribersNotSupported))
return
}
remoteMcu, ok := s.mcu.(signaling.RemoteMcu)
if !ok {
session.sendMessage(message.NewErrorServerMessage(RemoteSubscribersNotSupported))
return
}
subCtx, cancel := context.WithTimeout(ctx, remotePublisherTimeout)
defer cancel()
log.Printf("Creating remote subscriber for %s on %s", publisherId, cmd.RemoteUrl)
controller := &proxyRemotePublisher{
proxy: s,
remoteUrl: cmd.RemoteUrl,
publisherId: publisherId,
}
var publisher signaling.McuRemotePublisher
publisher, err = remoteMcu.NewRemotePublisher(subCtx, session, controller, cmd.StreamType)
if err != nil {
handleCreateError(err)
return
}
defer func() {
go publisher.Close(context.Background())
}()
subscriber, err = remoteMcu.NewRemoteSubscriber(subCtx, session, publisher)
if err != nil {
handleCreateError(err)
return
}
log.Printf("Created remote %s subscriber %s as %s for %s on %s", cmd.StreamType, subscriber.Id(), id, session.PublicId(), cmd.RemoteUrl)
} else {
subscriber, err = s.mcu.NewSubscriber(ctx, session, publisherId, cmd.StreamType, &emptyInitiator{})
if err != nil {
handleCreateError(err)
return
}
log.Printf("Created %s subscriber %s as %s for %s", cmd.StreamType, subscriber.Id(), id, session.PublicId())
}
session.StoreSubscriber(ctx, id, subscriber)
s.StoreClient(id, subscriber)
@ -743,6 +920,33 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
client.Close(context.Background())
}()
response := &signaling.ProxyServerMessage{
Id: message.Id,
Type: "command",
Command: &signaling.CommandProxyServerMessage{
Id: cmd.ClientId,
},
}
session.sendMessage(response)
case "publish-remote":
client := s.GetClient(cmd.ClientId)
if client == nil {
session.sendMessage(message.NewErrorServerMessage(UnknownClient))
return
}
publisher, ok := client.(signaling.McuPublisher)
if !ok {
session.sendMessage(message.NewErrorServerMessage(UnknownClient))
return
}
if err := publisher.PublishRemote(ctx, cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil {
log.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err)
session.sendMessage(message.NewWrappedErrorServerMessage(err))
return
}
response := &signaling.ProxyServerMessage{
Id: message.Id,
Type: "command",
@ -994,6 +1198,22 @@ func (s *ProxyServer) GetClient(id string) signaling.McuClient {
return s.clients[id]
}
func (s *ProxyServer) GetPublisher(publisherId string) signaling.McuPublisher {
s.clientsLock.RLock()
defer s.clientsLock.RUnlock()
for _, c := range s.clients {
pub, ok := c.(signaling.McuPublisher)
if !ok {
continue
}
if pub.Id() == publisherId {
return pub
}
}
return nil
}
func (s *ProxyServer) GetClientId(client signaling.McuClient) string {
s.clientsLock.RLock()
defer s.clientsLock.RUnlock()
@ -1055,3 +1275,25 @@ func (s *ProxyServer) metricsHandler(w http.ResponseWriter, r *http.Request) {
// Expose prometheus metrics at "/metrics".
promhttp.Handler().ServeHTTP(w, r)
}
func (s *ProxyServer) getRemoteConnection(ctx context.Context, url string) (*RemoteConnection, error) {
s.remoteConnectionsLock.Lock()
defer s.remoteConnectionsLock.Unlock()
conn, found := s.remoteConnections[url]
if found {
return conn, nil
}
conn, err := NewRemoteConnection(url, s.tokenId, s.tokenKey)
if err != nil {
return nil, err
}
if err := conn.Connect(ctx); err != nil {
return nil, err
}
s.remoteConnections[url] = conn
return conn, nil
}

View file

@ -26,6 +26,7 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"net"
"os"
"testing"
"time"
@ -120,3 +121,38 @@ func TestTokenInFuture(t *testing.T) {
t.Errorf("could have failed with TokenNotValidYet, got %s", err)
}
}
func TestPublicIPs(t *testing.T) {
public := []string{
"8.8.8.8",
"172.15.1.2",
"172.32.1.2",
"192.167.0.1",
"192.169.0.1",
}
private := []string{
"127.0.0.1",
"10.1.2.3",
"172.16.1.2",
"172.31.1.2",
"192.168.0.1",
"192.168.254.254",
}
for _, s := range public {
ip := net.ParseIP(s)
if len(ip) == 0 {
t.Errorf("invalid IP: %s", s)
} else if !IsPublicIP(ip) {
t.Errorf("should be public IP: %s", s)
}
}
for _, s := range private {
ip := net.ParseIP(s)
if len(ip) == 0 {
t.Errorf("invalid IP: %s", s)
} else if IsPublicIP(ip) {
t.Errorf("should be private IP: %s", s)
}
}
}