From 5a553fcc2dfd539de316cfa0fc054f789a8c9eea Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Fri, 7 Aug 2020 10:22:27 +0200 Subject: [PATCH 1/8] Move some signaling-specific functions of client to hub to allow reuse. --- src/signaling/client.go | 127 +++++++++++----------------------------- src/signaling/hub.go | 114 +++++++++++++++++++++++++++++++----- 2 files changed, 133 insertions(+), 108 deletions(-) diff --git a/src/signaling/client.go b/src/signaling/client.go index e1535c4..3523792 100644 --- a/src/signaling/client.go +++ b/src/signaling/client.go @@ -25,7 +25,6 @@ import ( "bytes" "encoding/json" "log" - "net" "strconv" "strings" "sync" @@ -52,14 +51,11 @@ const ( ) var ( - _noCountry string = "no-country" - noCountry *string = &_noCountry + noCountry string = "no-country" - _loopback string = "loopback" - loopback *string = &_loopback + loopback string = "loopback" - _unknownCountry string = "unknown-country" - unknownCountry *string = &_unknownCountry + unknownCountry string = "unknown-country" ) var ( @@ -72,8 +68,13 @@ var ( } ) +type WritableClientMessage interface { + json.Marshaler + + CloseAfterSend(session Session) bool +} + type Client struct { - hub *Hub conn *websocket.Conn addr string agent string @@ -85,9 +86,13 @@ type Client struct { mu sync.Mutex closeChan chan bool + + OnLookupCountry func(*Client) string + OnClosed func(*Client) + OnMessageReceived func(*Client, []byte) } -func NewClient(hub *Hub, conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) { +func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) { remoteAddress = strings.TrimSpace(remoteAddress) if remoteAddress == "" { remoteAddress = "unknown remote address" @@ -97,15 +102,27 @@ func NewClient(hub *Hub, conn *websocket.Conn, remoteAddress string, agent strin agent = "unknown user agent" } client := &Client{ - hub: hub, conn: conn, addr: remoteAddress, agent: agent, closeChan: make(chan bool, 1), + + OnLookupCountry: func(client *Client) string { return unknownCountry }, + OnClosed: func(client *Client) {}, + OnMessageReceived: func(client *Client, data []byte) {}, } return client, nil } +func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string) { + c.conn = conn + c.addr = remoteAddress + c.closeChan = make(chan bool, 1) + c.OnLookupCountry = func(client *Client) string { return unknownCountry } + c.OnClosed = func(client *Client) {} + c.OnMessageReceived = func(client *Client, data []byte) {} +} + func (c *Client) IsConnected() bool { return atomic.LoadUint32(&c.closed) == 0 } @@ -132,25 +149,7 @@ func (c *Client) UserAgent() string { func (c *Client) Country() string { if c.country == nil { - if c.hub.geoip == nil { - c.country = unknownCountry - return *c.country - } - ip := net.ParseIP(c.RemoteAddr()) - if ip == nil { - c.country = noCountry - return *c.country - } else if ip.IsLoopback() { - c.country = loopback - return *c.country - } - - country, err := c.hub.geoip.LookupCountry(ip) - if err != nil { - log.Printf("Could not lookup country for %s", ip) - c.country = unknownCountry - return *c.country - } + country := c.OnLookupCountry(c) c.country = &country } @@ -164,7 +163,7 @@ func (c *Client) Close() { c.closeChan <- true - c.hub.processUnregister(c) + c.OnClosed(c) c.SetSession(nil) c.mu.Lock() @@ -183,41 +182,6 @@ func (c *Client) SendError(e *Error) bool { return c.SendMessage(message) } -func (c *Client) SendRoom(message *ClientMessage, room *Room) bool { - response := &ServerMessage{ - Type: "room", - } - if message != nil { - response.Id = message.Id - } - if room == nil { - response.Room = &RoomServerMessage{ - RoomId: "", - } - } else { - response.Room = &RoomServerMessage{ - RoomId: room.id, - Properties: room.properties, - } - } - return c.SendMessage(response) -} - -func (c *Client) SendHelloResponse(message *ClientMessage, session *ClientSession) bool { - response := &ServerMessage{ - Id: message.Id, - Type: "hello", - Hello: &HelloServerMessage{ - Version: HelloVersion, - SessionId: session.PublicId(), - ResumeId: session.PrivateId(), - UserId: session.UserId(), - Server: c.hub.GetServerInfo(), - }, - } - return c.SendMessage(response) -} - func (c *Client) SendByeResponse(message *ClientMessage) bool { return c.SendByeResponseWithReason(message, "") } @@ -236,11 +200,11 @@ func (c *Client) SendByeResponseWithReason(message *ClientMessage, reason string return c.SendMessage(response) } -func (c *Client) SendMessage(message *ServerMessage) bool { +func (c *Client) SendMessage(message WritableClientMessage) bool { return c.writeMessage(message) } -func (c *Client) readPump() { +func (c *Client) ReadPump() { defer func() { c.Close() }() @@ -312,28 +276,7 @@ func (c *Client) readPump() { break } - var message ClientMessage - if err := message.UnmarshalJSON(decodeBuffer.Bytes()); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Error decoding message from client %s: %v", session.PublicId(), err) - } else { - log.Printf("Error decoding message from %s: %v", addr, err) - } - c.SendError(InvalidFormat) - continue - } - - if err := message.CheckValid(); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) - } else { - log.Printf("Invalid message %+v from %s: %v", message, addr, err) - } - c.SendMessage(message.NewErrorServerMessage(InvalidFormat)) - continue - } - - c.hub.processMessage(c, &message) + c.OnMessageReceived(c, decodeBuffer.Bytes()) } } @@ -407,7 +350,7 @@ func (c *Client) writeError(e error) bool { return false } -func (c *Client) writeMessage(message *ServerMessage) bool { +func (c *Client) writeMessage(message WritableClientMessage) bool { c.mu.Lock() defer c.mu.Unlock() if c.conn == nil { @@ -417,7 +360,7 @@ func (c *Client) writeMessage(message *ServerMessage) bool { return c.writeMessageLocked(message) } -func (c *Client) writeMessageLocked(message *ServerMessage) bool { +func (c *Client) writeMessageLocked(message WritableClientMessage) bool { if !c.writeInternal(message) { return false } @@ -458,7 +401,7 @@ func (c *Client) sendPing() bool { return true } -func (c *Client) writePump() { +func (c *Client) WritePump() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() diff --git a/src/signaling/hub.go b/src/signaling/hub.go index 1747081..0bbd173 100644 --- a/src/signaling/hub.go +++ b/src/signaling/hub.go @@ -30,6 +30,7 @@ import ( "fmt" "hash/fnv" "log" + "net" "net/http" "strings" "sync" @@ -633,7 +634,7 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B h.setDecodedSessionId(privateSessionId, privateSessionName, sessionIdData) h.setDecodedSessionId(publicSessionId, publicSessionName, sessionIdData) - client.SendHelloResponse(message, session) + h.sendHelloResponse(client, message, session) } func (h *Hub) processUnregister(client *Client) *ClientSession { @@ -656,7 +657,28 @@ func (h *Hub) processUnregister(client *Client) *ClientSession { return session } -func (h *Hub) processMessage(client *Client, message *ClientMessage) { +func (h *Hub) processMessage(client *Client, data []byte) { + var message ClientMessage + if err := message.UnmarshalJSON(data); err != nil { + if session := client.GetSession(); session != nil { + log.Printf("Error decoding message from client %s: %v", session.PublicId(), err) + } else { + log.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err) + } + client.SendError(InvalidFormat) + return + } + + if err := message.CheckValid(); err != nil { + if session := client.GetSession(); session != nil { + log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + } else { + log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) + } + client.SendMessage(message.NewErrorServerMessage(InvalidFormat)) + return + } + session := client.GetSession() if session == nil { if message.Type != "hello" { @@ -664,19 +686,19 @@ func (h *Hub) processMessage(client *Client, message *ClientMessage) { return } - h.processHello(client, message) + h.processHello(client, &message) return } switch message.Type { case "room": - h.processRoom(client, message) + h.processRoom(client, &message) case "message": - h.processMessageMsg(client, message) + h.processMessageMsg(client, &message) case "control": - h.processControlMsg(client, message) + h.processControlMsg(client, &message) case "bye": - h.processByeMsg(client, message) + h.processByeMsg(client, &message) case "hello": log.Printf("Ignore hello %+v for already authenticated connection %s", message.Hello, session.PublicId()) default: @@ -684,6 +706,21 @@ func (h *Hub) processMessage(client *Client, message *ClientMessage) { } } +func (h *Hub) sendHelloResponse(client *Client, message *ClientMessage, session *ClientSession) bool { + response := &ServerMessage{ + Id: message.Id, + Type: "hello", + Hello: &HelloServerMessage{ + Version: HelloVersion, + SessionId: session.PublicId(), + ResumeId: session.PrivateId(), + UserId: session.UserId(), + Server: h.GetServerInfo(), + }, + } + return client.SendMessage(response) +} + func (h *Hub) processHello(client *Client, message *ClientMessage) { resumeId := message.Hello.ResumeId if resumeId != "" { @@ -728,7 +765,7 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) { log.Printf("Resume session from %s in %s (%s) %s (private=%s)", client.RemoteAddr(), client.Country(), client.UserAgent(), session.PublicId(), session.PrivateId()) - client.SendHelloResponse(message, clientSession) + h.sendHelloResponse(client, message, clientSession) clientSession.NotifySessionResumed(client) return } @@ -839,6 +876,26 @@ func (h *Hub) disconnectByRoomSessionId(roomSessionId string) { session.Close() } +func (h *Hub) sendRoom(client *Client, message *ClientMessage, room *Room) bool { + response := &ServerMessage{ + Type: "room", + } + if message != nil { + response.Id = message.Id + } + if room == nil { + response.Room = &RoomServerMessage{ + RoomId: "", + } + } else { + response.Room = &RoomServerMessage{ + RoomId: room.id, + Properties: room.properties, + } + } + return client.SendMessage(response) +} + func (h *Hub) processRoom(client *Client, message *ClientMessage) { session := client.GetSession() roomId := message.Room.RoomId @@ -850,7 +907,7 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { // We can handle leaving a room directly. if session.LeaveRoom(true) != nil { // User was in a room before, so need to notify about leaving it. - client.SendRoom(message, nil) + h.sendRoom(client, message, nil) } if session.UserId() == "" && session.ClientType() != HelloClientTypeInternal { h.startWaitAnonymousClientRoom(client) @@ -965,7 +1022,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back if err := session.SubscribeRoomNats(h.nats, roomId, message.Room.SessionId); err != nil { client.SendMessage(message.NewWrappedErrorServerMessage(err)) // The client (implicitly) left the room due to an error. - client.SendRoom(nil, nil) + h.sendRoom(client, nil, nil) return } @@ -978,7 +1035,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back client.SendMessage(message.NewWrappedErrorServerMessage(err)) // The client (implicitly) left the room due to an error. session.UnsubscribeRoomNats() - client.SendRoom(nil, nil) + h.sendRoom(client, nil, nil) return } } @@ -992,7 +1049,7 @@ func (h *Hub) processJoinRoom(client *Client, message *ClientMessage, room *Back if room.Room.Permissions != nil { session.SetPermissions(*room.Room.Permissions) } - client.SendRoom(message, r) + h.sendRoom(client, message, r) h.notifyUserJoinedRoom(r, client, session, room.Room.Session) } @@ -1427,7 +1484,7 @@ func (h *Hub) processRoomDeleted(message *BackendServerRoomRequest) { switch sess := session.(type) { case *ClientSession: if client := sess.GetClient(); client != nil { - client.SendRoom(nil, nil) + h.sendRoom(client, nil, nil) } } } @@ -1477,6 +1534,23 @@ func getRealUserIP(r *http.Request) string { return r.RemoteAddr } +func (h *Hub) lookupClientCountry(client *Client) string { + ip := net.ParseIP(client.RemoteAddr()) + if ip == nil { + return noCountry + } else if ip.IsLoopback() { + return loopback + } + + country, err := h.geoip.LookupCountry(ip) + if err != nil { + log.Printf("Could not lookup country for %s", ip) + return unknownCountry + } + + return country +} + func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { addr := getRealUserIP(r) agent := r.Header.Get("User-Agent") @@ -1487,13 +1561,21 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { return } - client, err := NewClient(h, conn, addr, agent) + client, err := NewClient(conn, addr, agent) if err != nil { log.Printf("Could not create client for %s: %s", addr, err) return } + if h.geoip != nil { + client.OnLookupCountry = h.lookupClientCountry + } + client.OnMessageReceived = h.processMessage + client.OnClosed = func(client *Client) { + h.processUnregister(client) + } + h.processNewClient(client) - go client.writePump() - go client.readPump() + go client.WritePump() + go client.ReadPump() } From f4d4d5fb4de3f9c76495e71ccad3cf9ea22cfd0d Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Fri, 7 Aug 2020 10:23:47 +0200 Subject: [PATCH 2/8] Add callbacks for when MCU connection is established/list. --- src/signaling/mcu_common.go | 11 +++++++++- src/signaling/mcu_janus.go | 40 +++++++++++++++++++++++++++++++++++-- src/signaling/mcu_test.go | 6 ++++++ 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/signaling/mcu_common.go b/src/signaling/mcu_common.go index 7bfd250..3429e9d 100644 --- a/src/signaling/mcu_common.go +++ b/src/signaling/mcu_common.go @@ -22,6 +22,8 @@ package signaling import ( + "fmt" + "golang.org/x/net/context" ) @@ -31,8 +33,12 @@ const ( McuTypeDefault = McuTypeJanus ) +var ( + ErrNotConnected = fmt.Errorf("Not connected") +) + type McuListener interface { - Session + PublicId() string OnIceCandidate(client McuClient, candidate interface{}) OnIceCompleted(client McuClient) @@ -45,6 +51,9 @@ type Mcu interface { Start() error Stop() + SetOnConnected(func()) + SetOnDisconnected(func()) + GetStats() interface{} NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) diff --git a/src/signaling/mcu_janus.go b/src/signaling/mcu_janus.go index 0acd282..84bd18d 100644 --- a/src/signaling/mcu_janus.go +++ b/src/signaling/mcu_janus.go @@ -28,6 +28,7 @@ import ( "reflect" "strconv" "sync" + "sync/atomic" "time" "github.com/dlintw/goconf" @@ -64,8 +65,6 @@ var ( videoPublisherUserId: streamTypeVideo, screenPublisherUserId: streamTypeScreen, } - - ErrNotConnected = fmt.Errorf("Not connected") ) func getPluginValue(data janus.PluginData, pluginName string, key string) interface{} { @@ -161,8 +160,13 @@ type mcuJanus struct { reconnectInterval time.Duration connectedSince time.Time + onConnected atomic.Value + onDisconnected atomic.Value } +func emptyOnConnected() {} +func emptyOnDisconnected() {} + func NewMcuJanus(url string, config *goconf.ConfigFile, nats NatsClient) (Mcu, error) { maxStreamBitrate, _ := config.GetInt("mcu", "maxstreambitrate") if maxStreamBitrate <= 0 { @@ -190,6 +194,9 @@ func NewMcuJanus(url string, config *goconf.ConfigFile, nats NatsClient) (Mcu, e reconnectInterval: initialReconnectInterval, } + mcu.onConnected.Store(emptyOnConnected) + mcu.onDisconnected.Store(emptyOnDisconnected) + mcu.reconnectTimer = time.AfterFunc(mcu.reconnectInterval, mcu.doReconnect) mcu.reconnectTimer.Stop() if err := mcu.reconnect(); err != nil { @@ -269,6 +276,7 @@ func (m *mcuJanus) scheduleReconnect(err error) { func (m *mcuJanus) ConnectionInterrupted() { m.scheduleReconnect(nil) + m.notifyOnDisconnected() } func (m *mcuJanus) Start() error { @@ -314,6 +322,8 @@ func (m *mcuJanus) Start() error { log.Println("Created Janus handle", m.handle.Id) go m.run() + + m.notifyOnConnected() return nil } @@ -349,6 +359,32 @@ func (m *mcuJanus) Stop() { m.reconnectTimer.Stop() } +func (m *mcuJanus) SetOnConnected(f func()) { + if f == nil { + f = emptyOnConnected + } + + m.onConnected.Store(f) +} + +func (m *mcuJanus) notifyOnConnected() { + f := m.onConnected.Load().(func()) + f() +} + +func (m *mcuJanus) SetOnDisconnected(f func()) { + if f == nil { + f = emptyOnDisconnected + } + + m.onDisconnected.Store(f) +} + +func (m *mcuJanus) notifyOnDisconnected() { + f := m.onDisconnected.Load().(func()) + f() +} + type mcuJanusConnectionStats struct { Url string `json:"url"` Connected bool `json:"connected"` diff --git a/src/signaling/mcu_test.go b/src/signaling/mcu_test.go index 8bdaad3..7062a49 100644 --- a/src/signaling/mcu_test.go +++ b/src/signaling/mcu_test.go @@ -41,6 +41,12 @@ func (m *TestMCU) Start() error { func (m *TestMCU) Stop() { } +func (m *TestMCU) SetOnConnected(f func()) { +} + +func (m *TestMCU) SetOnDisconnected(f func()) { +} + func (m *TestMCU) GetStats() interface{} { return nil } From acbb47a1005a2144dd390cd9e026d4b8f8373189 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Fri, 7 Aug 2020 11:39:52 +0200 Subject: [PATCH 3/8] Add callback on received RTT. --- src/signaling/client.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/signaling/client.go b/src/signaling/client.go index 3523792..187ca2d 100644 --- a/src/signaling/client.go +++ b/src/signaling/client.go @@ -90,6 +90,7 @@ type Client struct { OnLookupCountry func(*Client) string OnClosed func(*Client) OnMessageReceived func(*Client, []byte) + OnRTTReceived func(*Client, time.Duration) } func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) { @@ -110,6 +111,7 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Clien OnLookupCountry: func(client *Client) string { return unknownCountry }, OnClosed: func(client *Client) {}, OnMessageReceived: func(client *Client, data []byte) {}, + OnRTTReceived: func(client *Client, rtt time.Duration) {}, } return client, nil } @@ -234,6 +236,7 @@ func (c *Client) ReadPump() { } else { log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) } + c.OnRTTReceived(c, rtt) } return nil }) From 4446b079514df3d5e3a6ea80e4bc9af7d807748a Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Fri, 7 Aug 2020 10:27:28 +0200 Subject: [PATCH 4/8] Add MCU type "proxy" that delegates to one or multiple MCU proxies. --- Makefile | 1 + dependencies.tsv | 1 + server.conf.in | 20 +- src/server/main.go | 2 + src/signaling/api_proxy.go | 254 +++++++++ src/signaling/mcu_common.go | 1 + src/signaling/mcu_proxy.go | 1035 +++++++++++++++++++++++++++++++++++ 7 files changed, 1309 insertions(+), 5 deletions(-) create mode 100644 src/signaling/api_proxy.go create mode 100644 src/signaling/mcu_proxy.go diff --git a/Makefile b/Makefile index ee0c9c1..9d1cdcf 100644 --- a/Makefile +++ b/Makefile @@ -97,6 +97,7 @@ coverhtml: dependencies vet common common: easyjson \ src/signaling/api_signaling_easyjson.go \ src/signaling/api_backend_easyjson.go \ + src/signaling/api_proxy_easyjson.go \ src/signaling/natsclient_easyjson.go \ src/signaling/room_easyjson.go diff --git a/dependencies.tsv b/dependencies.tsv index ae64cf9..1a865cd 100644 --- a/dependencies.tsv +++ b/dependencies.tsv @@ -10,3 +10,4 @@ github.com/notedit/janus-go git 8e6e2c423c03884d938d84442d37d6f6f5294197 2017-06 github.com/oschwald/maxminddb-golang git 1960b16a5147df3a4c61ac83b2f31cd8f811d609 2019-05-23T23:57:38Z golang.org/x/net git f01ecb60fe3835d80d9a0b7b2bf24b228c89260e 2017-07-11T18:12:19Z golang.org/x/sys git ac767d655b305d4e9612f5f6e33120b9176c4ad4 2018-07-15T08:55:29Z +gopkg.in/dgrijalva/jwt-go.v3 git 06ea1031745cb8b3dab3f6a236daf2b0aa468b7e 2018-03-08T23:13:08Z diff --git a/server.conf.in b/server.conf.in index 6488e4e..aaf10c2 100644 --- a/server.conf.in +++ b/server.conf.in @@ -98,21 +98,31 @@ connectionsperhost = 8 #url = nats://localhost:4222 [mcu] -# The type of the MCU to use. Currently only "janus" is supported. +# The type of the MCU to use. Currently only "janus" and "proxy" are supported. type = janus -# The URL to the websocket endpoint of the MCU server. Leave empty to disable -# MCU functionality. +# For type "janus": the URL to the websocket endpoint of the MCU server. +# For type "proxy": a space-separated list of proxy URLs to connect to. +# Leave empty to disable MCU functionality. url = -# The maximum bitrate per publishing stream (in bits per second). +# For type "janus": the maximum bitrate per publishing stream (in bits per +# second). # Defaults to 1 mbit/sec. #maxstreambitrate = 1048576 -# The maximum bitrate per screensharing stream (in bits per second). +# For type "janus": the maximum bitrate per screensharing stream (in bits per +# second). # Default is 2 mbit/sec. #maxscreenbitrate = 2097152 +# For type "proxy": the id of the token to use when connecting to proxy servers. +#token_id = server1 + +# For type "proxy": the private key for the configured token id to use when +# connecting to proxy servers. +#token_key = privkey.pem + [turn] # API key that the MCU will need to send when requesting TURN credentials. #apikey = the-api-key-for-the-rest-service diff --git a/src/server/main.go b/src/server/main.go index 468b899..f052aa7 100644 --- a/src/server/main.go +++ b/src/server/main.go @@ -166,6 +166,8 @@ func main() { switch mcuType { case signaling.McuTypeJanus: mcu, err = signaling.NewMcuJanus(mcuUrl, config, nats) + case signaling.McuTypeProxy: + mcu, err = signaling.NewMcuProxy(mcuUrl, config) default: log.Fatal("Unsupported MCU type: ", mcuType) } diff --git a/src/signaling/api_proxy.go b/src/signaling/api_proxy.go new file mode 100644 index 0000000..ad78a4c --- /dev/null +++ b/src/signaling/api_proxy.go @@ -0,0 +1,254 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "fmt" + + "gopkg.in/dgrijalva/jwt-go.v3" +) + +type ProxyClientMessage struct { + // The unique request id (optional). + Id string `json:"id,omitempty"` + + // The type of the request. + Type string `json:"type"` + + // Filled for type "hello" + Hello *HelloProxyClientMessage `json:"hello,omitempty"` + + Bye *ByeProxyClientMessage `json:"bye,omitempty"` + + Command *CommandProxyClientMessage `json:"command,omitempty"` + + Payload *PayloadProxyClientMessage `json:"payload,omitempty"` +} + +func (m *ProxyClientMessage) CheckValid() error { + switch m.Type { + case "": + return fmt.Errorf("type missing") + case "hello": + if m.Hello == nil { + return fmt.Errorf("hello missing") + } else if err := m.Hello.CheckValid(); err != nil { + return err + } + case "bye": + if m.Bye != nil { + // Bye contents are optional + if err := m.Bye.CheckValid(); err != nil { + return err + } + } + case "command": + if m.Command == nil { + return fmt.Errorf("command missing") + } else if err := m.Command.CheckValid(); err != nil { + return err + } + case "payload": + if m.Payload == nil { + return fmt.Errorf("payload missing") + } else if err := m.Payload.CheckValid(); err != nil { + return err + } + } + return nil +} + +func (m *ProxyClientMessage) NewErrorServerMessage(e *Error) *ProxyServerMessage { + return &ProxyServerMessage{ + Id: m.Id, + Type: "error", + Error: e, + } +} + +func (m *ProxyClientMessage) NewWrappedErrorServerMessage(e error) *ProxyServerMessage { + return m.NewErrorServerMessage(NewError("internal_error", e.Error())) +} + +// ProxyServerMessage is a message that is sent from the server to a client. +type ProxyServerMessage struct { + Id string `json:"id,omitempty"` + + Type string `json:"type"` + + Error *Error `json:"error,omitempty"` + + Hello *HelloProxyServerMessage `json:"hello,omitempty"` + + Bye *ByeProxyServerMessage `json:"bye,omitempty"` + + Command *CommandProxyServerMessage `json:"command,omitempty"` + + Payload *PayloadProxyServerMessage `json:"payload,omitempty"` + + Event *EventProxyServerMessage `json:"event,omitempty"` +} + +func (r *ProxyServerMessage) CloseAfterSend(session Session) bool { + if r.Type == "bye" { + return true + } + + return false +} + +// Type "hello" + +type TokenClaims struct { + jwt.StandardClaims +} + +type HelloProxyClientMessage struct { + Version string `json:"version"` + + ResumeId string `json:"resumeid"` + + Features []string `json:"features,omitempty"` + + // The authentication credentials. + Token string `json:"token"` +} + +func (m *HelloProxyClientMessage) CheckValid() error { + if m.Version != HelloVersion { + return fmt.Errorf("unsupported hello version: %s", m.Version) + } + if m.ResumeId == "" { + if m.Token == "" { + return fmt.Errorf("token missing") + } + } + return nil +} + +type HelloProxyServerMessage struct { + Version string `json:"version"` + + SessionId string `json:"sessionid"` + Server *HelloServerMessageServer `json:"server,omitempty"` +} + +// Type "bye" + +type ByeProxyClientMessage struct { +} + +func (m *ByeProxyClientMessage) CheckValid() error { + // No additional validation required. + return nil +} + +type ByeProxyServerMessage struct { + Reason string `json:"reason"` +} + +// Type "command" + +type CommandProxyClientMessage struct { + Type string `json:"type"` + + StreamType string `json:"streamType,omitempty"` + PublisherId string `json:"publisherId,omitempty"` + ClientId string `json:"clientId,omitempty"` +} + +func (m *CommandProxyClientMessage) CheckValid() error { + switch m.Type { + case "": + return fmt.Errorf("type missing") + case "create-publisher": + if m.StreamType == "" { + return fmt.Errorf("stream type missing") + } + case "create-subscriber": + if m.PublisherId == "" { + return fmt.Errorf("publisher id missing") + } + if m.StreamType == "" { + return fmt.Errorf("stream type missing") + } + case "delete-publisher": + fallthrough + case "delete-subscriber": + if m.ClientId == "" { + return fmt.Errorf("client id missing") + } + } + return nil +} + +type CommandProxyServerMessage struct { + Id string `json:"id,omitempty"` +} + +// Type "payload" + +type PayloadProxyClientMessage struct { + Type string `json:"type"` + + ClientId string `json:"clientId"` + Payload map[string]interface{} `json:"payload,omitempty"` +} + +func (m *PayloadProxyClientMessage) CheckValid() error { + switch m.Type { + case "": + return fmt.Errorf("type missing") + case "offer": + fallthrough + case "answer": + fallthrough + case "candidate": + if len(m.Payload) == 0 { + return fmt.Errorf("payload missing") + } + case "endOfCandidates": + fallthrough + case "requestoffer": + // No payload required. + } + if m.ClientId == "" { + return fmt.Errorf("client id missing") + } + return nil +} + +type PayloadProxyServerMessage struct { + Type string `json:"type"` + + ClientId string `json:"clientId"` + Payload map[string]interface{} `json:"payload"` +} + +// Type "event" + +type EventProxyServerMessage struct { + Type string `json:"type"` + + ClientId string `json:"clientId,omitempty"` + Load int64 `json:"load,omitempty"` +} diff --git a/src/signaling/mcu_common.go b/src/signaling/mcu_common.go index 3429e9d..20721af 100644 --- a/src/signaling/mcu_common.go +++ b/src/signaling/mcu_common.go @@ -29,6 +29,7 @@ import ( const ( McuTypeJanus = "janus" + McuTypeProxy = "proxy" McuTypeDefault = McuTypeJanus ) diff --git a/src/signaling/mcu_proxy.go b/src/signaling/mcu_proxy.go new file mode 100644 index 0000000..f5c0bbb --- /dev/null +++ b/src/signaling/mcu_proxy.go @@ -0,0 +1,1035 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "crypto/rsa" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/url" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/dlintw/goconf" + "github.com/gorilla/websocket" + + "golang.org/x/net/context" + + "gopkg.in/dgrijalva/jwt-go.v3" +) + +const ( + closeTimeout = time.Second + + proxyDebugMessages = false + + // Very high value so the connections get sorted at the end. + loadNotConnected = 1000000 + + // Sort connections by load every 10 publishing requests or once per second. + connectionSortRequests = 10 + connectionSortInterval = time.Second +) + +type mcuProxyPubSubCommon struct { + streamType string + proxyId string + conn *mcuProxyConnection + listener McuListener +} + +func (c *mcuProxyPubSubCommon) Id() string { + return c.proxyId +} + +func (c *mcuProxyPubSubCommon) StreamType() string { + return c.streamType +} + +func (c *mcuProxyPubSubCommon) doSendMessage(ctx context.Context, msg *ProxyClientMessage, callback func(error, map[string]interface{})) { + c.conn.performAsyncRequest(ctx, msg, func(err error, response *ProxyServerMessage) { + if err != nil { + callback(err, nil) + return + } + + if proxyDebugMessages { + log.Printf("Response from %s: %+v", c.conn.url, response) + } + if response.Type == "error" { + callback(response.Error, nil) + } else if response.Payload != nil { + callback(nil, response.Payload.Payload) + } else { + callback(nil, nil) + } + }) +} + +func (c *mcuProxyPubSubCommon) doProcessPayload(client McuClient, msg *PayloadProxyServerMessage) { + switch msg.Type { + case "candidate": + c.listener.OnIceCandidate(client, msg.Payload["candidate"]) + default: + log.Printf("Unsupported payload from %s: %+v", c.conn.url, msg) + } +} + +type mcuProxyPublisher struct { + mcuProxyPubSubCommon + + id string +} + +func newMcuProxyPublisher(id string, streamType string, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher { + return &mcuProxyPublisher{ + mcuProxyPubSubCommon: mcuProxyPubSubCommon{ + streamType: streamType, + proxyId: proxyId, + conn: conn, + listener: listener, + }, + id: id, + } +} + +func (p *mcuProxyPublisher) NotifyClosed() { + p.listener.PublisherClosed(p) + p.conn.removePublisher(p) +} + +func (p *mcuProxyPublisher) Close(ctx context.Context) { + p.NotifyClosed() + + msg := &ProxyClientMessage{ + Type: "command", + Command: &CommandProxyClientMessage{ + Type: "delete-publisher", + ClientId: p.proxyId, + }, + } + + if _, err := p.conn.performSyncRequest(ctx, msg); err != nil { + log.Printf("Could not delete publisher %s at %s: %s", p.proxyId, p.conn.url, err) + return + } +} + +func (p *mcuProxyPublisher) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { + msg := &ProxyClientMessage{ + Type: "payload", + Payload: &PayloadProxyClientMessage{ + Type: data.Type, + ClientId: p.proxyId, + Payload: data.Payload, + }, + } + + p.doSendMessage(ctx, msg, callback) +} + +func (p *mcuProxyPublisher) ProcessPayload(msg *PayloadProxyServerMessage) { + p.doProcessPayload(p, msg) +} + +func (p *mcuProxyPublisher) ProcessEvent(msg *EventProxyServerMessage) { + switch msg.Type { + case "ice-completed": + p.listener.OnIceCompleted(p) + case "publisher-closed": + p.NotifyClosed() + default: + log.Printf("Unsupported event from %s: %+v", p.conn.url, msg) + } +} + +type mcuProxySubscriber struct { + mcuProxyPubSubCommon + + publisherId string +} + +func newMcuProxySubscriber(publisherId string, streamType string, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxySubscriber { + return &mcuProxySubscriber{ + mcuProxyPubSubCommon: mcuProxyPubSubCommon{ + streamType: streamType, + proxyId: proxyId, + conn: conn, + listener: listener, + }, + + publisherId: publisherId, + } +} + +func (s *mcuProxySubscriber) Publisher() string { + return s.publisherId +} + +func (s *mcuProxySubscriber) NotifyClosed() { + s.listener.SubscriberClosed(s) + s.conn.removeSubscriber(s) +} + +func (s *mcuProxySubscriber) Close(ctx context.Context) { + s.NotifyClosed() + + msg := &ProxyClientMessage{ + Type: "command", + Command: &CommandProxyClientMessage{ + Type: "delete-subscriber", + ClientId: s.proxyId, + }, + } + + if _, err := s.conn.performSyncRequest(ctx, msg); err != nil { + log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn.url, err) + return + } +} + +func (s *mcuProxySubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) { + msg := &ProxyClientMessage{ + Type: "payload", + Payload: &PayloadProxyClientMessage{ + Type: data.Type, + ClientId: s.proxyId, + Payload: data.Payload, + }, + } + + s.doSendMessage(ctx, msg, callback) +} + +func (s *mcuProxySubscriber) ProcessPayload(msg *PayloadProxyServerMessage) { + s.doProcessPayload(s, msg) +} + +func (s *mcuProxySubscriber) ProcessEvent(msg *EventProxyServerMessage) { + switch msg.Type { + case "ice-completed": + s.listener.OnIceCompleted(s) + case "subscriber-closed": + s.NotifyClosed() + default: + log.Printf("Unsupported event from %s: %+v", s.conn.url, msg) + } +} + +type mcuProxyConnection struct { + proxy *mcuProxy + url *url.URL + + mu sync.Mutex + closeChan chan bool + closedChan chan bool + closed uint32 + conn *websocket.Conn + + connectedSince time.Time + reconnectInterval int64 + reconnectTimer *time.Timer + + msgId int64 + helloMsgId string + sessionId string + load int64 + + callbacks map[string]func(*ProxyServerMessage) + + publishersLock sync.RWMutex + publishers map[string]*mcuProxyPublisher + publisherIds map[string]string + + subscribersLock sync.RWMutex + subscribers map[string]*mcuProxySubscriber +} + +func newMcuProxyConnection(proxy *mcuProxy, baseUrl string) (*mcuProxyConnection, error) { + parsed, err := url.Parse(baseUrl) + if err != nil { + return nil, err + } + + conn := &mcuProxyConnection{ + proxy: proxy, + url: parsed, + closeChan: make(chan bool, 1), + closedChan: make(chan bool, 1), + reconnectInterval: int64(initialReconnectInterval), + load: loadNotConnected, + callbacks: make(map[string]func(*ProxyServerMessage)), + publishers: make(map[string]*mcuProxyPublisher), + publisherIds: make(map[string]string), + subscribers: make(map[string]*mcuProxySubscriber), + } + return conn, nil +} + +type mcuProxyConnectionStats struct { + Url string `json:"url"` + Connected bool `json:"connected"` + Publishers int64 `json:"publishers"` + Clients int64 `json:"clients"` + Uptime *time.Time `json:"uptime,omitempty"` +} + +func (c *mcuProxyConnection) GetStats() *mcuProxyConnectionStats { + result := &mcuProxyConnectionStats{ + Url: c.url.String(), + } + c.mu.Lock() + if c.conn != nil { + result.Connected = true + result.Uptime = &c.connectedSince + } + c.mu.Unlock() + c.publishersLock.RLock() + result.Publishers = int64(len(c.publishers)) + c.publishersLock.RUnlock() + c.subscribersLock.RLock() + result.Clients = int64(len(c.subscribers)) + c.subscribersLock.RUnlock() + result.Clients += result.Publishers + return result +} + +func (c *mcuProxyConnection) Load() int64 { + return atomic.LoadInt64(&c.load) +} + +func (c *mcuProxyConnection) readPump() { + defer func() { + if atomic.LoadUint32(&c.closed) == 0 { + c.scheduleReconnect() + } else { + c.closedChan <- true + } + }() + defer c.close() + defer atomic.StoreInt64(&c.load, loadNotConnected) + + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + for { + _, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived) { + log.Printf("Error reading from %s: %v", c.url, err) + } + break + } + + var msg ProxyServerMessage + if err := json.Unmarshal(message, &msg); err != nil { + log.Printf("Error unmarshaling %s from %s: %s", string(message), c.url, err) + continue + } + + c.processMessage(&msg) + } +} + +func (c *mcuProxyConnection) writePump() { + c.reconnectTimer = time.NewTimer(0) + for { + select { + case <-c.reconnectTimer.C: + c.reconnect() + case <-c.closeChan: + return + } + } +} + +func (c *mcuProxyConnection) start() error { + go c.writePump() + return nil +} + +func (c *mcuProxyConnection) sendClose() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn == nil { + return ErrNotConnected + } + + return c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) +} + +func (c *mcuProxyConnection) stop(ctx context.Context) { + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return + } + + c.closeChan <- true + if err := c.sendClose(); err != nil { + if err != ErrNotConnected { + log.Printf("Could not send close message to %s: %s", c.url, err) + } + c.close() + return + } + + select { + case <-c.closedChan: + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + log.Printf("Error waiting for connection to %s get closed: %s", c.url, err) + c.close() + } + } +} + +func (c *mcuProxyConnection) close() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn != nil { + c.conn.Close() + c.conn = nil + } +} + +func (c *mcuProxyConnection) scheduleReconnect() { + if err := c.sendClose(); err != nil && err != ErrNotConnected { + log.Printf("Could not send close message to %s: %s", c.url, err) + c.close() + } + + interval := atomic.LoadInt64(&c.reconnectInterval) + c.reconnectTimer.Reset(time.Duration(interval)) + + interval = interval * 2 + if interval > int64(maxReconnectInterval) { + interval = int64(maxReconnectInterval) + } + atomic.StoreInt64(&c.reconnectInterval, interval) +} + +func (c *mcuProxyConnection) reconnect() { + u, err := c.url.Parse("proxy") + if err != nil { + log.Printf("Could not resolve url to proxy at %s: %s", c.url, err) + c.scheduleReconnect() + return + } + if u.Scheme == "http" { + u.Scheme = "ws" + } else if u.Scheme == "https" { + u.Scheme = "wss" + } + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Printf("Could not connect to %s: %s", u, err) + c.scheduleReconnect() + return + } + + log.Printf("Connected to %s", u) + atomic.StoreUint32(&c.closed, 0) + + c.mu.Lock() + c.connectedSince = time.Now() + c.conn = conn + c.mu.Unlock() + + atomic.StoreInt64(&c.reconnectInterval, int64(initialReconnectInterval)) + if err := c.sendHello(); err != nil { + log.Printf("Could not send hello request to %s: %s", c.url, err) + c.scheduleReconnect() + return + } + + go c.readPump() +} + +func (c *mcuProxyConnection) removePublisher(publisher *mcuProxyPublisher) { + c.proxy.removePublisher(publisher) + + c.publishersLock.Lock() + defer c.publishersLock.Unlock() + + delete(c.publishers, publisher.proxyId) + delete(c.publisherIds, publisher.id+"|"+publisher.StreamType()) +} + +func (c *mcuProxyConnection) clearPublishers() { + c.publishersLock.Lock() + defer c.publishersLock.Unlock() + + go func(publishers map[string]*mcuProxyPublisher) { + for _, publisher := range publishers { + publisher.NotifyClosed() + } + }(c.publishers) + c.publishers = make(map[string]*mcuProxyPublisher) + c.publisherIds = make(map[string]string) +} + +func (c *mcuProxyConnection) removeSubscriber(subscriber *mcuProxySubscriber) { + c.subscribersLock.Lock() + defer c.subscribersLock.Unlock() + + delete(c.subscribers, subscriber.proxyId) +} + +func (c *mcuProxyConnection) clearSubscribers() { + c.subscribersLock.Lock() + defer c.subscribersLock.Unlock() + + go func(subscribers map[string]*mcuProxySubscriber) { + for _, subscriber := range subscribers { + subscriber.NotifyClosed() + } + }(c.subscribers) + c.subscribers = make(map[string]*mcuProxySubscriber) +} + +func (c *mcuProxyConnection) clearCallbacks() { + c.mu.Lock() + defer c.mu.Unlock() + + c.callbacks = make(map[string]func(*ProxyServerMessage)) +} + +func (c *mcuProxyConnection) getCallback(id string) func(*ProxyServerMessage) { + c.mu.Lock() + defer c.mu.Unlock() + + callback, found := c.callbacks[id] + if found { + delete(c.callbacks, id) + } + return callback +} + +func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { + if c.helloMsgId != "" && msg.Id == c.helloMsgId { + c.helloMsgId = "" + switch msg.Type { + case "error": + if msg.Error.Code == "no_such_session" { + log.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c.url) + c.clearPublishers() + c.clearSubscribers() + c.clearCallbacks() + c.sessionId = "" + if err := c.sendHello(); err != nil { + log.Printf("Could not send hello request to %s: %s", c.url, err) + c.scheduleReconnect() + } + return + } + + log.Printf("Hello connection to %s failed with %+v, reconnecting", c.url, msg.Error) + c.scheduleReconnect() + case "hello": + c.sessionId = msg.Hello.SessionId + log.Printf("Received session %s from %s", c.sessionId, c.url) + default: + log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c.url) + c.scheduleReconnect() + } + return + } + + if proxyDebugMessages { + log.Printf("Received from %s: %+v", c.url, msg) + } + callback := c.getCallback(msg.Id) + if callback != nil { + callback(msg) + return + } + + switch msg.Type { + case "payload": + c.processPayload(msg) + case "event": + c.processEvent(msg) + default: + log.Printf("Unsupported message received from %s: %+v", c.url, msg) + } +} + +func (c *mcuProxyConnection) processPayload(msg *ProxyServerMessage) { + payload := msg.Payload + c.publishersLock.RLock() + publisher, found := c.publishers[payload.ClientId] + c.publishersLock.RUnlock() + if found { + publisher.ProcessPayload(payload) + return + } + + c.subscribersLock.RLock() + subscriber, found := c.subscribers[payload.ClientId] + c.subscribersLock.RUnlock() + if found { + subscriber.ProcessPayload(payload) + return + } + + log.Printf("Received payload for unknown client %+v from %s", payload, c.url) +} + +func (c *mcuProxyConnection) processEvent(msg *ProxyServerMessage) { + event := msg.Event + if event.Type == "backend-disconnected" { + log.Printf("Upstream backend at %s got disconnected, reset MCU objects", c.url) + c.clearPublishers() + c.clearSubscribers() + c.clearCallbacks() + // TODO: Should we also reconnect? + return + } else if event.Type == "backend-connected" { + log.Printf("Upstream backend at %s is connected", c.url) + return + } else if event.Type == "update-load" { + if proxyDebugMessages { + log.Printf("Load of %s now at %d", c.url, event.Load) + } + atomic.StoreInt64(&c.load, event.Load) + return + } + + if proxyDebugMessages { + log.Printf("Process event from %s: %+v", c.url, event) + } + c.publishersLock.RLock() + publisher, found := c.publishers[event.ClientId] + c.publishersLock.RUnlock() + if found { + publisher.ProcessEvent(event) + return + } + + c.subscribersLock.RLock() + subscriber, found := c.subscribers[event.ClientId] + c.subscribersLock.RUnlock() + if found { + subscriber.ProcessEvent(event) + return + } + + log.Printf("Received event for unknown client %+v from %s", event, c.url) +} + +func (c *mcuProxyConnection) sendHello() error { + c.helloMsgId = strconv.FormatInt(atomic.AddInt64(&c.msgId, 1), 10) + msg := &ProxyClientMessage{ + Id: c.helloMsgId, + Type: "hello", + Hello: &HelloProxyClientMessage{ + Version: "1.0", + }, + } + if c.sessionId != "" { + msg.Hello.ResumeId = c.sessionId + } else { + claims := &TokenClaims{ + jwt.StandardClaims{ + IssuedAt: time.Now().Unix(), + Issuer: c.proxy.tokenId, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(c.proxy.tokenKey) + if err != nil { + return err + } + + msg.Hello.Token = tokenString + } + return c.sendMessage(msg) +} + +func (c *mcuProxyConnection) sendMessage(msg *ProxyClientMessage) error { + c.mu.Lock() + defer c.mu.Unlock() + + return c.sendMessageLocked(msg) +} + +func (c *mcuProxyConnection) sendMessageLocked(msg *ProxyClientMessage) error { + if proxyDebugMessages { + log.Printf("Send message to %s: %+v", c.url, msg) + } + if c.conn == nil { + return ErrNotConnected + } + return c.conn.WriteJSON(msg) +} + +func (c *mcuProxyConnection) performAsyncRequest(ctx context.Context, msg *ProxyClientMessage, callback func(err error, response *ProxyServerMessage)) { + msgId := strconv.FormatInt(atomic.AddInt64(&c.msgId, 1), 10) + msg.Id = msgId + + c.mu.Lock() + defer c.mu.Unlock() + c.callbacks[msgId] = func(msg *ProxyServerMessage) { + callback(nil, msg) + } + if err := c.sendMessageLocked(msg); err != nil { + delete(c.callbacks, msgId) + go callback(err, nil) + return + } +} + +func (c *mcuProxyConnection) performSyncRequest(ctx context.Context, msg *ProxyClientMessage) (*ProxyServerMessage, error) { + errChan := make(chan error, 1) + responseChan := make(chan *ProxyServerMessage, 1) + c.performAsyncRequest(ctx, msg, func(err error, response *ProxyServerMessage) { + if err != nil { + errChan <- err + } else { + responseChan <- response + } + }) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-errChan: + return nil, err + case response := <-responseChan: + return response, nil + } +} + +func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { + msg := &ProxyClientMessage{ + Type: "command", + Command: &CommandProxyClientMessage{ + Type: "create-publisher", + StreamType: streamType, + }, + } + + response, err := c.performSyncRequest(ctx, msg) + if err != nil { + // TODO: Cancel request + return nil, err + } + + proxyId := response.Command.Id + log.Printf("Created %s publisher %s on %s for %s", streamType, proxyId, c.url, id) + publisher := newMcuProxyPublisher(id, streamType, proxyId, c, listener) + c.publishersLock.Lock() + c.publishers[proxyId] = publisher + c.publisherIds[id+"|"+streamType] = proxyId + c.publishersLock.Unlock() + return publisher, nil +} + +func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { + c.publishersLock.Lock() + id, found := c.publisherIds[publisher+"|"+streamType] + c.publishersLock.Unlock() + if !found { + return nil, fmt.Errorf("Unknown publisher %s", publisher) + } + + msg := &ProxyClientMessage{ + Type: "command", + Command: &CommandProxyClientMessage{ + Type: "create-subscriber", + StreamType: streamType, + PublisherId: id, + }, + } + + response, err := c.performSyncRequest(ctx, msg) + if err != nil { + // TODO: Cancel request + return nil, err + } + + proxyId := response.Command.Id + log.Printf("Created %s subscriber %s on %s for %s", streamType, proxyId, c.url, publisher) + subscriber := newMcuProxySubscriber(publisher, streamType, proxyId, c, listener) + c.subscribersLock.Lock() + c.subscribers[proxyId] = subscriber + c.subscribersLock.Unlock() + return subscriber, nil +} + +type mcuProxy struct { + tokenId string + tokenKey *rsa.PrivateKey + + connections atomic.Value + connRequests int64 + nextSort int64 + + mu sync.RWMutex + publishers map[string]*mcuProxyConnection + + publisherWaitersId uint64 + publisherWaiters map[uint64]chan bool +} + +func NewMcuProxy(baseUrl string, config *goconf.ConfigFile) (Mcu, error) { + var connections []*mcuProxyConnection + + tokenId, _ := config.GetString("mcu", "token_id") + if tokenId == "" { + return nil, fmt.Errorf("No token id configured") + } + tokenKeyFilename, _ := config.GetString("mcu", "token_key") + if tokenKeyFilename == "" { + return nil, fmt.Errorf("No token key configured") + } + tokenKeyData, err := ioutil.ReadFile(tokenKeyFilename) + if err != nil { + return nil, fmt.Errorf("Could not read private key from %s: %s", tokenKeyFilename, err) + } + tokenKey, err := jwt.ParseRSAPrivateKeyFromPEM(tokenKeyData) + if err != nil { + return nil, fmt.Errorf("Could not parse private key from %s: %s", tokenKeyFilename, err) + } + + mcu := &mcuProxy{ + tokenId: tokenId, + tokenKey: tokenKey, + + publishers: make(map[string]*mcuProxyConnection), + + publisherWaiters: make(map[uint64]chan bool), + } + + for _, u := range strings.Split(baseUrl, " ") { + conn, err := newMcuProxyConnection(mcu, u) + if err != nil { + return nil, err + } + + connections = append(connections, conn) + } + if len(connections) == 0 { + return nil, fmt.Errorf("No MCU proxy connections configured") + } + + mcu.setConnections(connections) + return mcu, nil +} + +func (m *mcuProxy) setConnections(connections []*mcuProxyConnection) { + m.connections.Store(connections) +} + +func (m *mcuProxy) getConnections() []*mcuProxyConnection { + return m.connections.Load().([]*mcuProxyConnection) +} + +func (m *mcuProxy) Start() error { + for _, c := range m.getConnections() { + if err := c.start(); err != nil { + return err + } + } + return nil +} + +func (m *mcuProxy) Stop() { + for _, c := range m.getConnections() { + ctx, cancel := context.WithTimeout(context.Background(), closeTimeout) + defer cancel() + c.stop(ctx) + } +} + +func (m *mcuProxy) SetOnConnected(f func()) { + // Not supported. +} + +func (m *mcuProxy) SetOnDisconnected(f func()) { + // Not supported. +} + +type mcuProxyStats struct { + Publishers int64 `json:"publishers"` + Clients int64 `json:"clients"` + Details map[string]*mcuProxyConnectionStats `json:"details"` +} + +func (m *mcuProxy) GetStats() interface{} { + details := make(map[string]*mcuProxyConnectionStats) + result := &mcuProxyStats{ + Details: details, + } + for _, conn := range m.getConnections() { + stats := conn.GetStats() + result.Publishers += stats.Publishers + result.Clients += stats.Clients + details[stats.Url] = stats + } + return result +} + +type mcuProxyConnectionsList []*mcuProxyConnection + +func (l mcuProxyConnectionsList) Len() int { + return len(l) +} + +func (l mcuProxyConnectionsList) Less(i, j int) bool { + return l[i].Load() < l[j].Load() +} + +func (l mcuProxyConnectionsList) Swap(i, j int) { + l[i], l[j] = l[j], l[i] +} + +func (l mcuProxyConnectionsList) Sort() { + sort.Sort(l) +} + +func (m *mcuProxy) getSortedConnections() []*mcuProxyConnection { + connections := m.getConnections() + if len(connections) < 2 { + return connections + } + + // Connections are re-sorted every requests or + // every . + now := time.Now().UnixNano() + if atomic.AddInt64(&m.connRequests, 1)%connectionSortRequests == 0 || atomic.LoadInt64(&m.nextSort) <= now { + atomic.StoreInt64(&m.nextSort, now+int64(connectionSortInterval)) + + sorted := make(mcuProxyConnectionsList, len(connections)) + copy(sorted, connections) + + sorted.Sort() + + m.setConnections(sorted) + connections = sorted + } + + return connections +} + +func (m *mcuProxy) removePublisher(publisher *mcuProxyPublisher) { + m.mu.Lock() + defer m.mu.Unlock() + + delete(m.publishers, publisher.id+"|"+publisher.StreamType()) +} + +func (m *mcuProxy) wakeupWaiters() { + m.mu.RLock() + defer m.mu.RUnlock() + for _, ch := range m.publisherWaiters { + ch <- true + } +} + +func (m *mcuProxy) addWaiter(ch chan bool) uint64 { + id := m.publisherWaitersId + 1 + m.publisherWaitersId = id + m.publisherWaiters[id] = ch + return id +} + +func (m *mcuProxy) removeWaiter(id uint64) { + delete(m.publisherWaiters, id) +} + +func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { + connections := m.getSortedConnections() + for _, conn := range connections { + publisher, err := conn.newPublisher(ctx, listener, id, streamType) + if err != nil { + log.Printf("Could not create %s publisher for %s on %s: %s", streamType, id, conn.url, err) + continue + } + + m.mu.Lock() + m.publishers[id+"|"+streamType] = conn + m.mu.Unlock() + m.wakeupWaiters() + return publisher, nil + } + + return nil, fmt.Errorf("No MCU connection available") +} + +func (m *mcuProxy) getPublisherConnection(ctx context.Context, publisher string, streamType string) *mcuProxyConnection { + m.mu.RLock() + conn := m.publishers[publisher+"|"+streamType] + m.mu.RUnlock() + if conn != nil { + return conn + } + + log.Printf("No %s publisher %s found yet, deferring", streamType, publisher) + m.mu.Lock() + defer m.mu.Unlock() + + conn = m.publishers[publisher+"|"+streamType] + if conn != nil { + return conn + } + + ch := make(chan bool, 1) + id := m.addWaiter(ch) + defer m.removeWaiter(id) + + for { + m.mu.Unlock() + select { + case <-ch: + m.mu.Lock() + conn = m.publishers[publisher+"|"+streamType] + if conn != nil { + return conn + } + case <-ctx.Done(): + m.mu.Lock() + return nil + } + } +} + +func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) { + conn := m.getPublisherConnection(ctx, publisher, streamType) + if conn == nil { + return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher) + } + + return conn.newSubscriber(ctx, listener, publisher, streamType) +} From b7c258b459d8e7354045ec7b96c662129fef696e Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Fri, 7 Aug 2020 11:01:09 +0200 Subject: [PATCH 5/8] Add proxy service implementation. --- .gitignore | 2 + Makefile | 6 +- dependencies.tsv | 1 + proxy.conf.in | 55 ++ src/proxy/main.go | 153 ++++++ src/proxy/proxy_client.go | 55 ++ src/proxy/proxy_server.go | 962 +++++++++++++++++++++++++++++++++ src/proxy/proxy_session.go | 272 ++++++++++ src/signaling/api_signaling.go | 1 + 9 files changed, 1506 insertions(+), 1 deletion(-) create mode 100644 proxy.conf.in create mode 100644 src/proxy/main.go create mode 100644 src/proxy/proxy_client.go create mode 100644 src/proxy/proxy_server.go create mode 100644 src/proxy/proxy_session.go diff --git a/.gitignore b/.gitignore index ae1d4af..7296465 100644 --- a/.gitignore +++ b/.gitignore @@ -2,9 +2,11 @@ bin/ vendor/ *_easyjson.go +*.pem *.prof *.socket *.tar.gz cover.out +proxy.conf server.conf diff --git a/Makefile b/Makefile index 9d1cdcf..ea49861 100644 --- a/Makefile +++ b/Makefile @@ -109,10 +109,14 @@ server: dependencies common mkdir -p $(BINDIR) GOPATH=$(GOPATH) $(GO) build $(BUILDARGS) -ldflags '$(INTERNALLDFLAGS)' -o $(BINDIR)/signaling ./src/server/... +proxy: dependencies common + mkdir -p $(BINDIR) + GOPATH=$(GOPATH) $(GO) build $(BUILDARGS) -ldflags '$(INTERNALLDFLAGS)' -o $(BINDIR)/proxy ./src/proxy/... + clean: rm -f src/signaling/*_easyjson.go -build: server +build: server proxy tarball: git archive \ diff --git a/dependencies.tsv b/dependencies.tsv index 1a865cd..12f699b 100644 --- a/dependencies.tsv +++ b/dependencies.tsv @@ -1,4 +1,5 @@ github.com/dlintw/goconf git dcc070983490608a14480e3bf943bad464785df5 2012-02-28T08:26:10Z +github.com/google/uuid git 0e4e31197428a347842d152773b4cace4645ca25 2020-07-02T18:56:42Z github.com/gorilla/context git 08b5f424b9271eedf6f9f0ce86cb9396ed337a42 2016-08-17T18:46:32Z github.com/gorilla/mux git ac112f7d75a0714af1bd86ab17749b31f7809640 2017-07-04T07:43:45Z github.com/gorilla/securecookie git e59506cc896acb7f7bf732d4fdf5e25f7ccd8983 2017-02-24T19:38:04Z diff --git a/proxy.conf.in b/proxy.conf.in new file mode 100644 index 0000000..4770eda --- /dev/null +++ b/proxy.conf.in @@ -0,0 +1,55 @@ +[http] +# IP and port to listen on for HTTP requests. +# Comment line to disable the listener. +#listen = 127.0.0.1:9090 + +[app] +# Set to "true" to install pprof debug handlers. +# See "https://golang.org/pkg/net/http/pprof/" for further information. +#debug = false + +# ISO 3166 country this proxy is located at. This will be used by the signaling +# servers to determine the closest proxy for publishers. +#country = DE + +[sessions] +# Secret value used to generate checksums of sessions. This should be a random +# string of 32 or 64 bytes. +hashkey = secret-for-session-checksums + +# Optional key for encrypting data in the sessions. Must be either 16, 24 or +# 32 bytes. +# If no key is specified, data will not be encrypted (not recommended). +blockkey = -encryption-key- + +[nats] +# Url of NATS backend to use. This can also be a list of URLs to connect to +# multiple backends. For local development, this can be set to ":loopback:" +# to process NATS messages internally instead of sending them through an +# external NATS backend. +#url = nats://localhost:4222 + +[tokens] +# Mapping of = of signaling servers allowed to connect. +#server1 = pubkey1.pem +#server2 = pubkey2.pem + +[mcu] +# The type of the MCU to use. Currently only "janus" is supported. +type = janus + +# The URL to the websocket endpoint of the MCU server. +url = ws://localhost:8188/ + +# The maximum bitrate per publishing stream (in bits per second). +# Defaults to 1 mbit/sec. +#maxstreambitrate = 1048576 + +# The maximum bitrate per screensharing stream (in bits per second). +# Default is 2 mbit/sec. +#maxscreenbitrate = 2097152 + +[stats] +# Comma-separated list of IP addresses that are allowed to access the stats +# endpoint. Leave empty (or commented) to only allow access from "127.0.0.1". +#allowed_ips = diff --git a/src/proxy/main.go b/src/proxy/main.go new file mode 100644 index 0000000..de3ba1a --- /dev/null +++ b/src/proxy/main.go @@ -0,0 +1,153 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "runtime" + "strings" + "syscall" + "time" + + "github.com/dlintw/goconf" + "github.com/gorilla/mux" + "github.com/nats-io/go-nats" + + "signaling" +) + +var ( + version = "unreleased" + + configFlag = flag.String("config", "proxy.conf", "config file to use") + + showVersion = flag.Bool("version", false, "show version and quit") +) + +const ( + defaultReadTimeout = 15 + defaultWriteTimeout = 15 + + proxyDebugMessages = false +) + +func main() { + log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile) + flag.Parse() + + if *showVersion { + fmt.Printf("nextcloud-spreed-signaling-proxy version %s/%s\n", version, runtime.Version()) + os.Exit(0) + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + signal.Notify(sigChan, syscall.SIGHUP) + + log.Printf("Starting up version %s/%s as pid %d", version, runtime.Version(), os.Getpid()) + + config, err := goconf.ReadConfigFile(*configFlag) + if err != nil { + log.Fatal("Could not read configuration: ", err) + } + + cpus := runtime.NumCPU() + runtime.GOMAXPROCS(cpus) + log.Printf("Using a maximum of %d CPUs\n", cpus) + + natsUrl, _ := config.GetString("nats", "url") + if natsUrl == "" { + natsUrl = nats.DefaultURL + } + + nats, err := signaling.NewNatsClient(natsUrl) + if err != nil { + log.Fatal("Could not create NATS client: ", err) + } + + r := mux.NewRouter() + + proxy, err := NewProxyServer(r, version, config, nats) + if err != nil { + log.Fatal(err) + } + + if err := proxy.Start(config); err != nil { + log.Fatal(err) + } + defer proxy.Stop() + + if addr, _ := config.GetString("http", "listen"); addr != "" { + readTimeout, _ := config.GetInt("http", "readtimeout") + if readTimeout <= 0 { + readTimeout = defaultReadTimeout + } + writeTimeout, _ := config.GetInt("http", "writetimeout") + if writeTimeout <= 0 { + writeTimeout = defaultWriteTimeout + } + + for _, address := range strings.Split(addr, " ") { + go func(address string) { + log.Println("Listening on", address) + listener, err := net.Listen("tcp", address) + if err != nil { + log.Fatal("Could not start listening: ", err) + } + srv := &http.Server{ + Handler: r, + Addr: addr, + + ReadTimeout: time.Duration(readTimeout) * time.Second, + WriteTimeout: time.Duration(writeTimeout) * time.Second, + } + if err := srv.Serve(listener); err != nil { + log.Fatal("Could not start server: ", err) + } + }(address) + } + } + +loop: + for { + switch sig := <-sigChan; sig { + case os.Interrupt: + log.Println("Interrupted") + break loop + case syscall.SIGHUP: + log.Printf("Received SIGHUP, reloading %s", *configFlag) + config, err := goconf.ReadConfigFile(*configFlag) + if err != nil { + log.Printf("Could not read configuration from %s: %s", *configFlag, err) + continue + } + + proxy.Reload(config) + } + } +} diff --git a/src/proxy/proxy_client.go b/src/proxy/proxy_client.go new file mode 100644 index 0000000..5cb8d6a --- /dev/null +++ b/src/proxy/proxy_client.go @@ -0,0 +1,55 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "sync/atomic" + "unsafe" + + "github.com/gorilla/websocket" + + "signaling" +) + +type ProxyClient struct { + signaling.Client + + proxy *ProxyServer + + session unsafe.Pointer +} + +func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*ProxyClient, error) { + client := &ProxyClient{ + proxy: proxy, + } + client.SetConn(conn, addr) + return client, nil +} + +func (c *ProxyClient) GetSession() *ProxySession { + return (*ProxySession)(atomic.LoadPointer(&c.session)) +} + +func (c *ProxyClient) SetSession(session *ProxySession) { + atomic.StorePointer(&c.session, unsafe.Pointer(session)) +} diff --git a/src/proxy/proxy_server.go b/src/proxy/proxy_server.go new file mode 100644 index 0000000..a82a42a --- /dev/null +++ b/src/proxy/proxy_server.go @@ -0,0 +1,962 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "crypto/rsa" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/pprof" + "os" + "os/signal" + runtimepprof "runtime/pprof" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/dlintw/goconf" + "github.com/google/uuid" + "github.com/gorilla/mux" + "github.com/gorilla/securecookie" + "github.com/gorilla/websocket" + + "golang.org/x/net/context" + + "gopkg.in/dgrijalva/jwt-go.v3" + + "signaling" +) + +const ( + // Buffer sizes when reading/writing websocket connections. + websocketReadBufferSize = 4096 + websocketWriteBufferSize = 4096 + + initialMcuRetry = time.Second + maxMcuRetry = time.Second * 16 + + updateLoadInterval = time.Second + expireSessionsInterval = 10 * time.Second + + // Maximum age a token may have to prevent reuse of old tokens. + maxTokenAge = 5 * time.Minute +) + +type ContextKey string + +var ( + ContextKeySession = ContextKey("session") + + TimeoutCreatingPublisher = signaling.NewError("timeout", "Timeout creating publisher.") + TimeoutCreatingSubscriber = signaling.NewError("timeout", "Timeout creating subscriber.") + TokenAuthFailed = signaling.NewError("auth_failed", "The token could not be authenticated.") + TokenExpired = signaling.NewError("token_expired", "The token is expired.") + UnknownClient = signaling.NewError("unknown_client", "Unknown client id given.") + UnsupportedCommand = signaling.NewError("bad_request", "Unsupported command received.") + UnsupportedMessage = signaling.NewError("bad_request", "Unsupported message received.") + UnsupportedPayload = signaling.NewError("unsupported_payload", "Unsupported payload type.") +) + +type ProxyServer struct { + version string + country string + + url string + nats signaling.NatsClient + mcu signaling.Mcu + stopped uint32 + load int64 + + upgrader websocket.Upgrader + + tokenKeys atomic.Value + statsAllowedIps map[string]bool + + sid uint64 + cookie *securecookie.SecureCookie + sessions map[uint64]*ProxySession + sessionsLock sync.RWMutex + + clients map[string]signaling.McuClient + clientIds map[string]string + clientsLock sync.RWMutex +} + +func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile, nats signaling.NatsClient) (*ProxyServer, error) { + hashKey, _ := config.GetString("sessions", "hashkey") + switch len(hashKey) { + case 32: + case 64: + default: + log.Printf("WARNING: The sessions hash key should be 32 or 64 bytes but is %d bytes", len(hashKey)) + } + + blockKey, _ := config.GetString("sessions", "blockkey") + blockBytes := []byte(blockKey) + switch len(blockKey) { + case 0: + blockBytes = nil + case 16: + case 24: + case 32: + default: + return nil, fmt.Errorf("The sessions block key must be 16, 24 or 32 bytes but is %d bytes", len(blockKey)) + } + + tokenKeys := make(map[string]*rsa.PublicKey) + options, _ := config.GetOptions("tokens") + for _, id := range options { + filename, _ := config.GetString("tokens", id) + if filename == "" { + return nil, fmt.Errorf("No filename given for token %s", id) + } + + keyData, err := ioutil.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("Could not read public key from %s: %s", filename, err) + } + key, err := jwt.ParseRSAPublicKeyFromPEM(keyData) + if err != nil { + return nil, fmt.Errorf("Could not parse public key from %s: %s", filename, err) + } + + tokenKeys[id] = key + } + + var keyIds []string + for k, _ := range tokenKeys { + keyIds = append(keyIds, k) + } + sort.Strings(keyIds) + log.Printf("Enabled token keys: %v", keyIds) + + statsAllowed, _ := config.GetString("stats", "allowed_ips") + var statsAllowedIps map[string]bool + if statsAllowed == "" { + log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1") + statsAllowedIps = map[string]bool{ + "127.0.0.1": true, + } + } else { + log.Printf("Only allowing access to the stats endpoing from %s", statsAllowed) + statsAllowedIps = make(map[string]bool) + for _, ip := range strings.Split(statsAllowed, ",") { + ip = strings.TrimSpace(ip) + if ip != "" { + statsAllowedIps[ip] = true + } + } + } + + country, _ := config.GetString("app", "country") + if country != "" { + log.Printf("Sending %s as country information", country) + } else { + log.Printf("Not sending country information") + } + + result := &ProxyServer{ + version: version, + country: country, + + nats: nats, + + upgrader: websocket.Upgrader{ + ReadBufferSize: websocketReadBufferSize, + WriteBufferSize: websocketWriteBufferSize, + }, + + statsAllowedIps: statsAllowedIps, + + cookie: securecookie.New([]byte(hashKey), blockBytes).MaxAge(0), + sessions: make(map[uint64]*ProxySession), + + clients: make(map[string]signaling.McuClient), + clientIds: make(map[string]string), + } + + result.setTokenKeys(tokenKeys) + result.upgrader.CheckOrigin = result.checkOrigin + + if debug, _ := config.GetBool("app", "debug"); debug { + log.Println("Installing debug handlers in \"/debug/pprof\"") + r.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index)) + r.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) + r.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) + r.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) + r.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace)) + for _, profile := range runtimepprof.Profiles() { + name := profile.Name() + r.Handle("/debug/pprof/"+name, pprof.Handler(name)) + } + } + + r.HandleFunc("/proxy", result.setCommonHeaders(result.proxyHandler)).Methods("GET") + r.HandleFunc("/stats", result.setCommonHeaders(result.validateStatsRequest(result.statsHandler))).Methods("GET") + return result, nil +} + +func (s *ProxyServer) checkOrigin(r *http.Request) bool { + // We allow any Origin to connect to the service. + return true +} + +func (s *ProxyServer) setTokenKeys(keys map[string]*rsa.PublicKey) { + s.tokenKeys.Store(keys) +} + +func (s *ProxyServer) getTokenKeys() map[string]*rsa.PublicKey { + return s.tokenKeys.Load().(map[string]*rsa.PublicKey) +} + +func (s *ProxyServer) Start(config *goconf.ConfigFile) error { + s.url, _ = config.GetString("mcu", "url") + if s.url == "" { + return fmt.Errorf("No MCU server url configured") + } + + mcuType, _ := config.GetString("mcu", "type") + if mcuType == "" { + mcuType = signaling.McuTypeDefault + } + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + defer signal.Stop(interrupt) + + var err error + var mcu signaling.Mcu + mcuRetry := initialMcuRetry + mcuRetryTimer := time.NewTimer(mcuRetry) + for { + switch mcuType { + case signaling.McuTypeJanus: + mcu, err = signaling.NewMcuJanus(s.url, config, s.nats) + default: + return fmt.Errorf("Unsupported MCU type: %s", mcuType) + } + if err == nil { + mcu.SetOnConnected(s.onMcuConnected) + mcu.SetOnDisconnected(s.onMcuDisconnected) + err = mcu.Start() + if err != nil { + log.Printf("Could not create %s MCU at %s: %s", mcuType, s.url, err) + } + } + if err == nil { + break + } + + log.Printf("Could not initialize %s MCU at %s (%s) will retry in %s", mcuType, s.url, err, mcuRetry) + mcuRetryTimer.Reset(mcuRetry) + select { + case <-interrupt: + return fmt.Errorf("Cancelled") + case <-mcuRetryTimer.C: + // Retry connection + mcuRetry = mcuRetry * 2 + if mcuRetry > maxMcuRetry { + mcuRetry = maxMcuRetry + } + } + } + + s.mcu = mcu + + go s.run() + + return nil +} + +func (s *ProxyServer) run() { + updateLoadTicker := time.NewTicker(updateLoadInterval) + expireSessionsTicker := time.NewTicker(expireSessionsInterval) +loop: + for { + select { + case <-updateLoadTicker.C: + if atomic.LoadUint32(&s.stopped) != 0 { + break loop + } + s.updateLoad() + case <-expireSessionsTicker.C: + if atomic.LoadUint32(&s.stopped) != 0 { + break loop + } + s.expireSessions() + } + } +} + +func (s *ProxyServer) updateLoad() { + // TODO: Take maximum bandwidth of clients into account when calculating + // load (screensharing requires more than regular audio/video). + load := s.GetClientCount() + if load == atomic.LoadInt64(&s.load) { + return + } + + atomic.StoreInt64(&s.load, load) + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "update-load", + Load: load, + }, + } + + s.IterateSessions(func(session *ProxySession) { + session.sendMessage(msg) + }) +} + +func (s *ProxyServer) getExpiredSessions() []*ProxySession { + var expired []*ProxySession + s.IterateSessions(func(session *ProxySession) { + if session.IsExpired() { + expired = append(expired, session) + } + }) + return expired +} + +func (s *ProxyServer) expireSessions() { + expired := s.getExpiredSessions() + if len(expired) == 0 { + return + } + + s.sessionsLock.Lock() + defer s.sessionsLock.Unlock() + for _, session := range expired { + if !session.IsExpired() { + // Session was used while waiting for the lock. + continue + } + + log.Printf("Delete expired session %s", session.PublicId()) + s.deleteSessionLocked(session.Sid()) + } +} + +func (s *ProxyServer) Stop() { + if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) { + return + } + + s.mcu.Stop() +} + +func (s *ProxyServer) Reload(config *goconf.ConfigFile) { + tokenKeys := make(map[string]*rsa.PublicKey) + options, _ := config.GetOptions("tokens") + for _, id := range options { + filename, _ := config.GetString("tokens", id) + if filename == "" { + log.Printf("No filename given for token %s, ignoring", id) + continue + } + + keyData, err := ioutil.ReadFile(filename) + if err != nil { + log.Printf("Could not read public key from %s, ignoring: %s", filename, err) + continue + } + key, err := jwt.ParseRSAPublicKeyFromPEM(keyData) + if err != nil { + log.Printf("Could not parse public key from %s, ignoring: %s", filename, err) + continue + } + + tokenKeys[id] = key + } + + if len(tokenKeys) == 0 { + log.Printf("No token keys loaded") + } else { + var keyIds []string + for k, _ := range tokenKeys { + keyIds = append(keyIds, k) + } + sort.Strings(keyIds) + log.Printf("Enabled token keys: %v", keyIds) + } + s.setTokenKeys(tokenKeys) +} + +func (s *ProxyServer) setCommonHeaders(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Server", "nextcloud-spreed-signaling-proxy/"+s.version) + f(w, r) + } +} + +func getRealUserIP(r *http.Request) string { + // Note this function assumes it is running behind a trusted proxy, so + // the headers can be trusted. + if ip := r.Header.Get("X-Real-IP"); ip != "" { + return ip + } + + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + // Result could be a list "clientip, proxy1, proxy2", so only use first element. + if pos := strings.Index(ip, ","); pos >= 0 { + ip = strings.TrimSpace(ip[:pos]) + } + return ip + } + + return r.RemoteAddr +} + +func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { + addr := getRealUserIP(r) + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Could not upgrade request from %s: %s", addr, err) + return + } + + client, err := NewProxyClient(s, conn, addr) + if err != nil { + log.Printf("Could not create client for %s: %s", addr, err) + return + } + + client.OnClosed = s.clientClosed + client.OnMessageReceived = func(c *signaling.Client, data []byte) { + s.processMessage(client, data) + } + client.OnRTTReceived = func(c *signaling.Client, rtt time.Duration) { + if session := client.GetSession(); session != nil { + session.MarkUsed() + } + } + + go client.WritePump() + go client.ReadPump() +} + +func (s *ProxyServer) clientClosed(client *signaling.Client) { + log.Printf("Connection from %s closed", client.RemoteAddr()) +} + +func (s *ProxyServer) onMcuConnected() { + log.Printf("Connection to %s established", s.url) + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "backend-connected", + }, + } + + s.IterateSessions(func(session *ProxySession) { + session.sendMessage(msg) + }) +} + +func (s *ProxyServer) onMcuDisconnected() { + if atomic.LoadUint32(&s.stopped) != 0 { + // Shutting down, no need to notify. + return + } + + log.Printf("Connection to %s lost", s.url) + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "backend-disconnected", + }, + } + + s.IterateSessions(func(session *ProxySession) { + session.sendMessage(msg) + session.NotifyDisconnected() + }) +} + +func (s *ProxyServer) sendCurrentLoad(session *ProxySession) { + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "update-load", + Load: atomic.LoadInt64(&s.load), + }, + } + session.sendMessage(msg) +} + +func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { + if proxyDebugMessages { + log.Printf("Message: %s", string(data)) + } + var message signaling.ProxyClientMessage + if err := message.UnmarshalJSON(data); err != nil { + if session := client.GetSession(); session != nil { + log.Printf("Error decoding message from client %s: %v", session.PublicId(), err) + } else { + log.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err) + } + client.SendError(signaling.InvalidFormat) + return + } + + if err := message.CheckValid(); err != nil { + if session := client.GetSession(); session != nil { + log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + } else { + log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) + } + client.SendMessage(message.NewErrorServerMessage(signaling.InvalidFormat)) + return + } + + session := client.GetSession() + if session == nil { + if message.Type != "hello" { + client.SendMessage(message.NewErrorServerMessage(signaling.HelloExpected)) + return + } + + var session *ProxySession + if resumeId := message.Hello.ResumeId; resumeId != "" { + var data signaling.SessionIdData + if s.cookie.Decode("session", resumeId, &data) == nil { + session = s.GetSession(data.Sid) + } + + if session == nil { + client.SendMessage(message.NewErrorServerMessage(signaling.NoSuchSession)) + return + } + + log.Printf("Resumed session %s", session.PublicId()) + s.sendCurrentLoad(session) + } else { + var err error + if session, err = s.NewSession(message.Hello); err != nil { + if e, ok := err.(*signaling.Error); ok { + client.SendMessage(message.NewErrorServerMessage(e)) + } else { + client.SendMessage(message.NewWrappedErrorServerMessage(err)) + } + return + } + } + + prev := session.SetClient(client) + if prev != nil { + msg := &signaling.ProxyServerMessage{ + Type: "bye", + Bye: &signaling.ByeProxyServerMessage{ + Reason: "session_resumed", + }, + } + prev.SendMessage(msg) + } + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "hello", + Hello: &signaling.HelloProxyServerMessage{ + Version: signaling.HelloVersion, + SessionId: session.PublicId(), + Server: &signaling.HelloServerMessageServer{ + Version: s.version, + Country: s.country, + }, + }, + } + client.SendMessage(response) + s.sendCurrentLoad(session) + return + } + + ctx := context.WithValue(context.Background(), ContextKeySession, session) + session.MarkUsed() + + switch message.Type { + case "command": + s.processCommand(ctx, client, session, &message) + case "payload": + s.processPayload(ctx, client, session, &message) + default: + session.sendMessage(message.NewErrorServerMessage(UnsupportedMessage)) + } +} + +func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) { + cmd := message.Command + switch cmd.Type { + case "create-publisher": + id := uuid.New().String() + publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.StreamType) + if err == context.DeadlineExceeded { + log.Printf("Timeout while creating %s publisher %s for %s", cmd.StreamType, id, session.PublicId()) + session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingPublisher)) + return + } else if err != nil { + log.Printf("Error while creating %s publisher %s for %s: %s", cmd.StreamType, id, session.PublicId(), err) + session.sendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + + log.Printf("Created %s publisher %s as %s", cmd.StreamType, publisher.Id(), id) + session.StorePublisher(ctx, id, publisher) + s.StoreClient(id, publisher) + + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: id, + }, + } + session.sendMessage(response) + case "create-subscriber": + id := uuid.New().String() + publisherId := cmd.PublisherId + subscriber, err := s.mcu.NewSubscriber(ctx, session, publisherId, cmd.StreamType) + if err == context.DeadlineExceeded { + log.Printf("Timeout while creating %s subscriber on %s for %s", cmd.StreamType, publisherId, session.PublicId()) + session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingSubscriber)) + return + } else if err != nil { + log.Printf("Error while creating %s subscriber on %s for %s: %s", cmd.StreamType, publisherId, session.PublicId(), err) + session.sendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + + log.Printf("Created %s subscriber %s as %s", cmd.StreamType, subscriber.Id(), id) + session.StoreSubscriber(ctx, id, subscriber) + s.StoreClient(id, subscriber) + + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: id, + }, + } + session.sendMessage(response) + case "delete-publisher": + client := s.GetClient(cmd.ClientId) + if client == nil { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + if session.DeletePublisher(client) == "" { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + s.DeleteClient(cmd.ClientId, client) + + go func() { + log.Printf("Closing %s publisher %s as %s", client.StreamType(), client.Id(), cmd.ClientId) + client.Close(context.Background()) + }() + + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: cmd.ClientId, + }, + } + session.sendMessage(response) + case "delete-subscriber": + client := s.GetClient(cmd.ClientId) + if client == nil { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + subscriber, ok := client.(signaling.McuSubscriber) + if !ok { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + if session.DeleteSubscriber(subscriber) == "" { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + s.DeleteClient(cmd.ClientId, client) + + go func() { + log.Printf("Closing %s subscriber %s as %s", client.StreamType(), client.Id(), cmd.ClientId) + client.Close(context.Background()) + }() + + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: cmd.ClientId, + }, + } + session.sendMessage(response) + default: + log.Printf("Unsupported command %+v", message.Command) + session.sendMessage(message.NewErrorServerMessage(UnsupportedCommand)) + } +} + +func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) { + payload := message.Payload + mcuClient := s.GetClient(payload.ClientId) + if mcuClient == nil { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + var mcuData *signaling.MessageClientMessageData + switch payload.Type { + case "offer": + fallthrough + case "answer": + fallthrough + case "candidate": + mcuData = &signaling.MessageClientMessageData{ + Type: payload.Type, + Payload: payload.Payload, + } + case "endOfCandidates": + // Ignore but confirm, not passed along to Janus anyway. + session.sendMessage(&signaling.ProxyServerMessage{ + Id: message.Id, + Type: "payload", + Payload: &signaling.PayloadProxyServerMessage{ + Type: payload.Type, + ClientId: payload.ClientId, + }, + }) + return + case "requestoffer": + fallthrough + case "sendoffer": + mcuData = &signaling.MessageClientMessageData{ + Type: payload.Type, + } + default: + session.sendMessage(message.NewErrorServerMessage(UnsupportedPayload)) + return + } + + mcuClient.SendMessage(ctx, nil, mcuData, func(err error, response map[string]interface{}) { + var responseMsg *signaling.ProxyServerMessage + if err != nil { + log.Printf("Error sending %s to %s client %s: %s", mcuData, mcuClient.StreamType(), payload.ClientId, err) + responseMsg = message.NewWrappedErrorServerMessage(err) + } else { + responseMsg = &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "payload", + Payload: &signaling.PayloadProxyServerMessage{ + Type: payload.Type, + ClientId: payload.ClientId, + Payload: response, + }, + } + } + + session.sendMessage(responseMsg) + }) +} + +func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*ProxySession, error) { + if proxyDebugMessages { + log.Printf("Hello: %+v", hello) + } + + token, err := jwt.ParseWithClaims(hello.Token, &signaling.TokenClaims{}, func(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + log.Printf("Unexpected signing method: %v", token.Header["alg"]) + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + + claims, ok := token.Claims.(*signaling.TokenClaims) + if !ok { + log.Printf("Unsupported claims type: %+v", token.Claims) + return nil, fmt.Errorf("Unsupported claims type") + } + + tokenKeys := s.getTokenKeys() + publicKey := tokenKeys[claims.Issuer] + if publicKey == nil { + log.Printf("Issuer %s is not supported", claims.Issuer) + return nil, fmt.Errorf("No key found for issuer") + } + return publicKey, nil + }) + if err != nil { + return nil, TokenAuthFailed + } + + claims, ok := token.Claims.(*signaling.TokenClaims) + if !ok || !token.Valid { + return nil, TokenAuthFailed + } + + minIssuedAt := time.Now().Add(-maxTokenAge) + if issuedAt := time.Unix(claims.IssuedAt, 0); issuedAt.Before(minIssuedAt) { + return nil, TokenExpired + } + + sid := atomic.AddUint64(&s.sid, 1) + for sid == 0 { + sid = atomic.AddUint64(&s.sid, 1) + } + + sessionIdData := &signaling.SessionIdData{ + Sid: sid, + Created: time.Now(), + } + + encoded, err := s.cookie.Encode("session", sessionIdData) + if err != nil { + return nil, err + } + + log.Printf("Created session %s for %+v", encoded, claims) + session := NewProxySession(s, sid, encoded) + s.StoreSession(sid, session) + return session, nil +} + +func (s *ProxyServer) StoreSession(id uint64, session *ProxySession) { + s.sessionsLock.Lock() + defer s.sessionsLock.Unlock() + s.sessions[id] = session +} + +func (s *ProxyServer) GetSession(id uint64) *ProxySession { + s.sessionsLock.RLock() + defer s.sessionsLock.RUnlock() + return s.sessions[id] +} + +func (s *ProxyServer) GetSessionsCount() int64 { + s.sessionsLock.RLock() + defer s.sessionsLock.RUnlock() + return int64(len(s.sessions)) +} + +func (s *ProxyServer) IterateSessions(f func(*ProxySession)) { + s.sessionsLock.RLock() + defer s.sessionsLock.RUnlock() + + for _, session := range s.sessions { + f(session) + } +} + +func (s *ProxyServer) DeleteSession(id uint64) { + s.sessionsLock.Lock() + defer s.sessionsLock.Unlock() + s.deleteSessionLocked(id) +} + +func (s *ProxyServer) deleteSessionLocked(id uint64) { + delete(s.sessions, id) +} + +func (s *ProxyServer) StoreClient(id string, client signaling.McuClient) { + s.clientsLock.Lock() + defer s.clientsLock.Unlock() + s.clients[id] = client + s.clientIds[client.Id()] = id +} + +func (s *ProxyServer) DeleteClient(id string, client signaling.McuClient) { + s.clientsLock.Lock() + defer s.clientsLock.Unlock() + delete(s.clients, id) + delete(s.clientIds, client.Id()) +} + +func (s *ProxyServer) GetClientCount() int64 { + s.clientsLock.RLock() + defer s.clientsLock.RUnlock() + return int64(len(s.clients)) +} + +func (s *ProxyServer) GetClient(id string) signaling.McuClient { + s.clientsLock.RLock() + defer s.clientsLock.RUnlock() + return s.clients[id] +} + +func (s *ProxyServer) GetClientId(client signaling.McuClient) string { + s.clientsLock.RLock() + defer s.clientsLock.RUnlock() + return s.clientIds[client.Id()] +} + +func (s *ProxyServer) getStats() map[string]interface{} { + result := map[string]interface{}{ + "sessions": s.GetSessionsCount(), + "mcu": s.mcu.GetStats(), + } + return result +} + +func (s *ProxyServer) validateStatsRequest(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + addr := getRealUserIP(r) + if strings.Contains(addr, ":") { + if host, _, err := net.SplitHostPort(addr); err == nil { + addr = host + } + } + if !s.statsAllowedIps[addr] { + http.Error(w, "Authentication check failed", http.StatusForbidden) + return + } + + f(w, r) + } +} + +func (s *ProxyServer) statsHandler(w http.ResponseWriter, r *http.Request) { + stats := s.getStats() + statsData, err := json.MarshalIndent(stats, "", " ") + if err != nil { + log.Printf("Could not serialize stats %+v: %s", stats, err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(http.StatusOK) + w.Write(statsData) +} diff --git a/src/proxy/proxy_session.go b/src/proxy/proxy_session.go new file mode 100644 index 0000000..ebbb85d --- /dev/null +++ b/src/proxy/proxy_session.go @@ -0,0 +1,272 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "log" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/context" + + "signaling" +) + +const ( + // Sessions expire if they have not been used for one minute. + sessionExpirationTime = time.Minute +) + +type ProxySession struct { + proxy *ProxyServer + id string + sid uint64 + lastUsed int64 + + clientLock sync.Mutex + client *ProxyClient + pendingMessages []*signaling.ProxyServerMessage + + publishersLock sync.Mutex + publishers map[string]signaling.McuPublisher + publisherIds map[string]string + + subscribersLock sync.Mutex + subscribers map[string]signaling.McuSubscriber + subscriberIds map[string]string +} + +func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession { + return &ProxySession{ + proxy: proxy, + id: id, + sid: sid, + lastUsed: time.Now().UnixNano(), + + publishers: make(map[string]signaling.McuPublisher), + publisherIds: make(map[string]string), + + subscribers: make(map[string]signaling.McuSubscriber), + subscriberIds: make(map[string]string), + } +} + +func (s *ProxySession) PublicId() string { + return s.id +} + +func (s *ProxySession) Sid() uint64 { + return s.sid +} + +func (s *ProxySession) LastUsed() time.Time { + lastUsed := atomic.LoadInt64(&s.lastUsed) + return time.Unix(0, lastUsed) +} + +func (s *ProxySession) IsExpired() bool { + expiresAt := s.LastUsed().Add(sessionExpirationTime) + return expiresAt.Before(time.Now()) +} + +func (s *ProxySession) MarkUsed() { + now := time.Now() + atomic.StoreInt64(&s.lastUsed, now.UnixNano()) +} + +func (s *ProxySession) SetClient(client *ProxyClient) *ProxyClient { + s.clientLock.Lock() + prev := s.client + s.client = client + var messages []*signaling.ProxyServerMessage + if client != nil { + messages, s.pendingMessages = s.pendingMessages, nil + } + s.clientLock.Unlock() + if prev != nil { + prev.SetSession(nil) + } + if client != nil { + s.MarkUsed() + client.SetSession(s) + for _, msg := range messages { + client.SendMessage(msg) + } + } + return prev +} + +func (s *ProxySession) OnIceCandidate(client signaling.McuClient, candidate interface{}) { + id := s.proxy.GetClientId(client) + if id == "" { + log.Printf("Received candidate %+v from unknown %s client %s (%+v)", candidate, client.StreamType(), client.Id(), client) + return + } + + msg := &signaling.ProxyServerMessage{ + Type: "payload", + Payload: &signaling.PayloadProxyServerMessage{ + Type: "candidate", + ClientId: id, + Payload: map[string]interface{}{ + "candidate": candidate, + }, + }, + } + s.sendMessage(msg) +} + +func (s *ProxySession) sendMessage(message *signaling.ProxyServerMessage) { + var client *ProxyClient + s.clientLock.Lock() + client = s.client + if client == nil { + s.pendingMessages = append(s.pendingMessages, message) + } + s.clientLock.Unlock() + if client != nil { + client.SendMessage(message) + } +} + +func (s *ProxySession) OnIceCompleted(client signaling.McuClient) { + id := s.proxy.GetClientId(client) + if id == "" { + log.Printf("Received ice completed event from unknown %s client %s (%+v)", client.StreamType(), client.Id(), client) + return + } + + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "ice-completed", + ClientId: id, + }, + } + s.sendMessage(msg) +} + +func (s *ProxySession) PublisherClosed(publisher signaling.McuPublisher) { + if id := s.DeletePublisher(publisher); id != "" { + s.proxy.DeleteClient(id, publisher) + + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "publisher-closed", + ClientId: id, + }, + } + s.sendMessage(msg) + } +} + +func (s *ProxySession) SubscriberClosed(subscriber signaling.McuSubscriber) { + if id := s.DeleteSubscriber(subscriber); id != "" { + s.proxy.DeleteClient(id, subscriber) + + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "subscriber-closed", + ClientId: id, + }, + } + s.sendMessage(msg) + } +} + +func (s *ProxySession) StorePublisher(ctx context.Context, id string, publisher signaling.McuPublisher) { + s.publishersLock.Lock() + defer s.publishersLock.Unlock() + + s.publishers[id] = publisher + s.publisherIds[publisher.Id()] = id +} + +func (s *ProxySession) DeletePublisher(publisher signaling.McuPublisher) string { + s.publishersLock.Lock() + defer s.publishersLock.Unlock() + + id, found := s.publisherIds[publisher.Id()] + if !found { + return "" + } + + delete(s.publishers, id) + delete(s.publisherIds, publisher.Id()) + return id +} + +func (s *ProxySession) StoreSubscriber(ctx context.Context, id string, subscriber signaling.McuSubscriber) { + s.subscribersLock.Lock() + defer s.subscribersLock.Unlock() + + s.subscribers[id] = subscriber + s.subscriberIds[subscriber.Id()] = id +} + +func (s *ProxySession) DeleteSubscriber(subscriber signaling.McuSubscriber) string { + s.subscribersLock.Lock() + defer s.subscribersLock.Unlock() + + id, found := s.subscriberIds[subscriber.Id()] + if !found { + return "" + } + + delete(s.subscribers, id) + delete(s.subscriberIds, subscriber.Id()) + return id +} + +func (s *ProxySession) clearPublishers() { + s.publishersLock.Lock() + defer s.publishersLock.Unlock() + + go func(publishers map[string]signaling.McuPublisher) { + for _, publisher := range publishers { + publisher.Close(context.Background()) + } + }(s.publishers) + s.publishers = make(map[string]signaling.McuPublisher) + s.publisherIds = make(map[string]string) +} + +func (s *ProxySession) clearSubscribers() { + s.publishersLock.Lock() + defer s.publishersLock.Unlock() + + go func(subscribers map[string]signaling.McuSubscriber) { + for _, subscriber := range subscribers { + subscriber.Close(context.Background()) + } + }(s.subscribers) + s.subscribers = make(map[string]signaling.McuSubscriber) + s.subscriberIds = make(map[string]string) +} + +func (s *ProxySession) NotifyDisconnected() { + s.clearPublishers() + s.clearSubscribers() +} diff --git a/src/signaling/api_signaling.go b/src/signaling/api_signaling.go index 5f43969..5762459 100644 --- a/src/signaling/api_signaling.go +++ b/src/signaling/api_signaling.go @@ -278,6 +278,7 @@ const ( type HelloServerMessageServer struct { Version string `json:"version"` Features []string `json:"features,omitempty"` + Country string `json:"country,omitempty"` } type HelloServerMessage struct { From 2626b1ac6c50ae757eeefd282bee0e1b11fe716d Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Fri, 7 Aug 2020 11:47:10 +0200 Subject: [PATCH 6/8] Add note on using multiple Janus servers. --- README.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/README.md b/README.md index 358ab23..095c2f6 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,31 @@ The maximum bandwidth per publishing stream can also be configured in the section `[mcu]`, see properties `maxstreambitrate` and `maxscreenbitrate`. +### Use multiple Janus servers + +To scale the setup and add high availability, a signaling server can connect to +one or multiple proxy servers that each provide access to a single Janus server. + +For that, set the `type` key in section `[mcu]` to `proxy` and set `url` to a +space-separated list of URLs where a proxy server is running. + +Each signaling server that connects to a proxy needs a unique token id and a +public / private RSA keypair. The token id must be configured as `token_id` in +section `[mcu]`, the path to the private key file as `token_key`. + + +### Setup of proxy server + +The proxy server is built with the standard make command `make build` as +`bin/proxy` binary. Copy the `proxy.conf.in` as `proxy.conf` and edit section +`[tokens]` to the list of allowed token ids and filenames of the public keys +for each token id. See the comments in `proxy.conf.in` for other configuration +options. + +When the proxy process receives a `SIGHUP` signal, the list of allowed token +ids / public keys is reloaded. + + ## Setup of frontend webserver Usually the standalone signaling server is running behind a webserver that does From d9d11b58e129d53cd62791d529895204489ac0c6 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Fri, 7 Aug 2020 15:16:13 +0200 Subject: [PATCH 7/8] Shutdown a proxy gracefully on SIGUSR1. No new publishers will be created by the proxy, existing publishers can still be subscribed. After all clients have disconnected, the process will terminate. --- README.md | 4 ++- src/proxy/main.go | 32 ++++++++++------- src/proxy/proxy_server.go | 70 ++++++++++++++++++++++++++++++++++++-- src/signaling/mcu_proxy.go | 21 ++++++++++-- 4 files changed, 109 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 095c2f6..785c34a 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,9 @@ for each token id. See the comments in `proxy.conf.in` for other configuration options. When the proxy process receives a `SIGHUP` signal, the list of allowed token -ids / public keys is reloaded. +ids / public keys is reloaded. A `SIGUSR1` signal can be used to shutdown a +proxy process gracefully after all clients have been disconnected. No new +publishers will be accepted in this case. ## Setup of frontend webserver diff --git a/src/proxy/main.go b/src/proxy/main.go index de3ba1a..e3505e8 100644 --- a/src/proxy/main.go +++ b/src/proxy/main.go @@ -68,6 +68,7 @@ func main() { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt) signal.Notify(sigChan, syscall.SIGHUP) + signal.Notify(sigChan, syscall.SIGUSR1) log.Printf("Starting up version %s/%s as pid %d", version, runtime.Version(), os.Getpid()) @@ -135,19 +136,26 @@ func main() { loop: for { - switch sig := <-sigChan; sig { - case os.Interrupt: - log.Println("Interrupted") - break loop - case syscall.SIGHUP: - log.Printf("Received SIGHUP, reloading %s", *configFlag) - config, err := goconf.ReadConfigFile(*configFlag) - if err != nil { - log.Printf("Could not read configuration from %s: %s", *configFlag, err) - continue + select { + case sig := <-sigChan: + switch sig { + case os.Interrupt: + log.Println("Interrupted") + break loop + case syscall.SIGHUP: + log.Printf("Received SIGHUP, reloading %s", *configFlag) + if config, err := goconf.ReadConfigFile(*configFlag); err != nil { + log.Printf("Could not read configuration from %s: %s", *configFlag, err) + } else { + proxy.Reload(config) + } + case syscall.SIGUSR1: + log.Printf("Received SIGUSR1, scheduling server to shutdown") + proxy.ScheduleShutdown() } - - proxy.Reload(config) + case <-proxy.ShutdownChannel(): + log.Printf("All clients disconnected, shutting down") + break loop } } } diff --git a/src/proxy/proxy_server.go b/src/proxy/proxy_server.go index a82a42a..802c916 100644 --- a/src/proxy/proxy_server.go +++ b/src/proxy/proxy_server.go @@ -80,6 +80,7 @@ var ( UnsupportedCommand = signaling.NewError("bad_request", "Unsupported command received.") UnsupportedMessage = signaling.NewError("bad_request", "Unsupported message received.") UnsupportedPayload = signaling.NewError("unsupported_payload", "Unsupported payload type.") + ShutdownScheduled = signaling.NewError("shutdown_scheduled", "The server is scheduled to shutdown.") ) type ProxyServer struct { @@ -92,6 +93,9 @@ type ProxyServer struct { stopped uint32 load int64 + shutdownChannel chan bool + shutdownScheduled uint32 + upgrader websocket.Upgrader tokenKeys atomic.Value @@ -186,6 +190,8 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile, na nats: nats, + shutdownChannel: make(chan bool, 1), + upgrader: websocket.Upgrader{ ReadBufferSize: websocketReadBufferSize, WriteBufferSize: websocketWriteBufferSize, @@ -322,6 +328,11 @@ func (s *ProxyServer) updateLoad() { } atomic.StoreInt64(&s.load, load) + if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + // Server is scheduled to shutdown, no need to update clients with current load. + return + } + msg := &signaling.ProxyServerMessage{ Type: "event", Event: &signaling.EventProxyServerMessage{ @@ -372,6 +383,32 @@ func (s *ProxyServer) Stop() { s.mcu.Stop() } +func (s *ProxyServer) ShutdownChannel() chan bool { + return s.shutdownChannel +} + +func (s *ProxyServer) ScheduleShutdown() { + if !atomic.CompareAndSwapUint32(&s.shutdownScheduled, 0, 1) { + return + } + + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "shutdown-scheduled", + }, + } + s.IterateSessions(func(session *ProxySession) { + session.sendMessage(msg) + }) + + if s.GetClientCount() == 0 { + go func() { + s.shutdownChannel <- true + }() + } +} + func (s *ProxyServer) Reload(config *goconf.ConfigFile) { tokenKeys := make(map[string]*rsa.PublicKey) options, _ := config.GetOptions("tokens") @@ -511,6 +548,16 @@ func (s *ProxyServer) sendCurrentLoad(session *ProxySession) { session.sendMessage(msg) } +func (s *ProxyServer) sendShutdownScheduled(session *ProxySession) { + msg := &signaling.ProxyServerMessage{ + Type: "event", + Event: &signaling.EventProxyServerMessage{ + Type: "shutdown-scheduled", + }, + } + session.sendMessage(msg) +} + func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { if proxyDebugMessages { log.Printf("Message: %s", string(data)) @@ -556,7 +603,11 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { } log.Printf("Resumed session %s", session.PublicId()) - s.sendCurrentLoad(session) + if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + s.sendShutdownScheduled(session) + } else { + s.sendCurrentLoad(session) + } } else { var err error if session, err = s.NewSession(message.Hello); err != nil { @@ -592,7 +643,11 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { }, } client.SendMessage(response) - s.sendCurrentLoad(session) + if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + s.sendShutdownScheduled(session) + } else { + s.sendCurrentLoad(session) + } return } @@ -613,6 +668,11 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s cmd := message.Command switch cmd.Type { case "create-publisher": + if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + session.sendMessage(message.NewErrorServerMessage(ShutdownScheduled)) + return + } + id := uuid.New().String() publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.StreamType) if err == context.DeadlineExceeded { @@ -901,6 +961,12 @@ func (s *ProxyServer) DeleteClient(id string, client signaling.McuClient) { defer s.clientsLock.Unlock() delete(s.clients, id) delete(s.clientIds, client.Id()) + + if len(s.clients) == 0 && atomic.LoadUint32(&s.shutdownScheduled) != 0 { + go func() { + s.shutdownChannel <- true + }() + } } func (s *ProxyServer) GetClientCount() int64 { diff --git a/src/signaling/mcu_proxy.go b/src/signaling/mcu_proxy.go index f5c0bbb..032d1f1 100644 --- a/src/signaling/mcu_proxy.go +++ b/src/signaling/mcu_proxy.go @@ -254,6 +254,7 @@ type mcuProxyConnection struct { connectedSince time.Time reconnectInterval int64 reconnectTimer *time.Timer + shutdownScheduled uint32 msgId int64 helloMsgId string @@ -323,6 +324,10 @@ func (c *mcuProxyConnection) Load() int64 { return atomic.LoadInt64(&c.load) } +func (c *mcuProxyConnection) IsShutdownScheduled() bool { + return atomic.LoadUint32(&c.shutdownScheduled) != 0 +} + func (c *mcuProxyConnection) readPump() { defer func() { if atomic.LoadUint32(&c.closed) == 0 { @@ -467,6 +472,7 @@ func (c *mcuProxyConnection) reconnect() { c.mu.Unlock() atomic.StoreInt64(&c.reconnectInterval, int64(initialReconnectInterval)) + atomic.StoreUint32(&c.shutdownScheduled, 0) if err := c.sendHello(); err != nil { log.Printf("Could not send hello request to %s: %s", c.url, err) c.scheduleReconnect() @@ -608,22 +614,27 @@ func (c *mcuProxyConnection) processPayload(msg *ProxyServerMessage) { func (c *mcuProxyConnection) processEvent(msg *ProxyServerMessage) { event := msg.Event - if event.Type == "backend-disconnected" { + switch event.Type { + case "backend-disconnected": log.Printf("Upstream backend at %s got disconnected, reset MCU objects", c.url) c.clearPublishers() c.clearSubscribers() c.clearCallbacks() // TODO: Should we also reconnect? return - } else if event.Type == "backend-connected" { + case "backend-connected": log.Printf("Upstream backend at %s is connected", c.url) return - } else if event.Type == "update-load" { + case "update-load": if proxyDebugMessages { log.Printf("Load of %s now at %d", c.url, event.Load) } atomic.StoreInt64(&c.load, event.Load) return + case "shutdown-scheduled": + log.Printf("Proxy %s is scheduled to shutdown", c.url) + atomic.StoreUint32(&c.shutdownScheduled, 1) + return } if proxyDebugMessages { @@ -972,6 +983,10 @@ func (m *mcuProxy) removeWaiter(id uint64) { func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { connections := m.getSortedConnections() for _, conn := range connections { + if conn.IsShutdownScheduled() { + continue + } + publisher, err := conn.newPublisher(ctx, listener, id, streamType) if err != nil { log.Printf("Could not create %s publisher for %s on %s: %s", streamType, id, conn.url, err) From ea74a54d11a878bc59cf16785dab5576bc151904 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 12 Aug 2020 15:42:11 +0200 Subject: [PATCH 8/8] Select proxy based on country of publisher (if known). The connections that have been sorted by load are also sorted by country of publisher and continent of publisher, e.g. for worldwide proxies, the ones closest to the publisher will be preferred. --- src/proxy/proxy_server.go | 13 ++++- src/signaling/client.go | 15 ++++++ src/signaling/clientsession.go | 7 ++- src/signaling/hub.go | 5 +- src/signaling/mcu_common.go | 6 ++- src/signaling/mcu_janus.go | 2 +- src/signaling/mcu_proxy.go | 80 ++++++++++++++++++++++++++++-- src/signaling/mcu_proxy_test.go | 86 +++++++++++++++++++++++++++++++++ src/signaling/mcu_test.go | 2 +- 9 files changed, 205 insertions(+), 11 deletions(-) create mode 100644 src/signaling/mcu_proxy_test.go diff --git a/src/proxy/proxy_server.go b/src/proxy/proxy_server.go index 802c916..05c4fa0 100644 --- a/src/proxy/proxy_server.go +++ b/src/proxy/proxy_server.go @@ -178,8 +178,11 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile, na } country, _ := config.GetString("app", "country") - if country != "" { + country = strings.ToUpper(country) + if signaling.IsValidCountry(country) { log.Printf("Sending %s as country information", country) + } else if country != "" { + return nil, fmt.Errorf("Invalid country: %s", country) } else { log.Printf("Not sending country information") } @@ -664,6 +667,12 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { } } +type emptyInitiator struct{} + +func (i *emptyInitiator) Country() string { + return "" +} + func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) { cmd := message.Command switch cmd.Type { @@ -674,7 +683,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } id := uuid.New().String() - publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.StreamType) + publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.StreamType, &emptyInitiator{}) if err == context.DeadlineExceeded { log.Printf("Timeout while creating %s publisher %s for %s", cmd.StreamType, id, session.PublicId()) session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingPublisher)) diff --git a/src/signaling/client.go b/src/signaling/client.go index 187ca2d..ebfef43 100644 --- a/src/signaling/client.go +++ b/src/signaling/client.go @@ -58,6 +58,21 @@ var ( unknownCountry string = "unknown-country" ) +func IsValidCountry(country string) bool { + switch country { + case "": + fallthrough + case noCountry: + fallthrough + case loopback: + fallthrough + case unknownCountry: + return false + default: + return true + } +} + var ( InvalidFormat = NewError("invalid_format", "Invalid data format.") diff --git a/src/signaling/clientsession.go b/src/signaling/clientsession.go index 7a9d2fc..8a0b496 100644 --- a/src/signaling/clientsession.go +++ b/src/signaling/clientsession.go @@ -436,6 +436,10 @@ func (s *ClientSession) GetClient() *Client { s.mu.Lock() defer s.mu.Unlock() + return s.getClientUnlocked() +} + +func (s *ClientSession) getClientUnlocked() *Client { return s.client } @@ -554,9 +558,10 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea publisher, found := s.publishers[streamType] if !found { + client := s.getClientUnlocked() s.mu.Unlock() var err error - publisher, err = mcu.NewPublisher(ctx, s, s.PublicId(), streamType) + publisher, err = mcu.NewPublisher(ctx, s, s.PublicId(), streamType, client) s.mu.Lock() if err != nil { return nil, err diff --git a/src/signaling/hub.go b/src/signaling/hub.go index 0bbd173..9e352f9 100644 --- a/src/signaling/hub.go +++ b/src/signaling/hub.go @@ -1544,10 +1544,13 @@ func (h *Hub) lookupClientCountry(client *Client) string { country, err := h.geoip.LookupCountry(ip) if err != nil { - log.Printf("Could not lookup country for %s", ip) + log.Printf("Could not lookup country for %s: %s", ip, err) return unknownCountry } + if country == "" { + return unknownCountry + } return country } diff --git a/src/signaling/mcu_common.go b/src/signaling/mcu_common.go index 20721af..c821ff0 100644 --- a/src/signaling/mcu_common.go +++ b/src/signaling/mcu_common.go @@ -48,6 +48,10 @@ type McuListener interface { SubscriberClosed(subscriber McuSubscriber) } +type McuInitiator interface { + Country() string +} + type Mcu interface { Start() error Stop() @@ -57,7 +61,7 @@ type Mcu interface { GetStats() interface{} - NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) + NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) } diff --git a/src/signaling/mcu_janus.go b/src/signaling/mcu_janus.go index 84bd18d..6f59524 100644 --- a/src/signaling/mcu_janus.go +++ b/src/signaling/mcu_janus.go @@ -635,7 +635,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id string, st return handle, response.Session, roomId, nil } -func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { +func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) { if _, found := streamTypeUserIds[streamType]; !found { return nil, fmt.Errorf("Unsupported stream type %s", streamType) } diff --git a/src/signaling/mcu_proxy.go b/src/signaling/mcu_proxy.go index 032d1f1..145c4f9 100644 --- a/src/signaling/mcu_proxy.go +++ b/src/signaling/mcu_proxy.go @@ -260,6 +260,7 @@ type mcuProxyConnection struct { helloMsgId string sessionId string load int64 + country atomic.Value callbacks map[string]func(*ProxyServerMessage) @@ -289,6 +290,7 @@ func newMcuProxyConnection(proxy *mcuProxy, baseUrl string) (*mcuProxyConnection publisherIds: make(map[string]string), subscribers: make(map[string]*mcuProxySubscriber), } + conn.country.Store("") return conn, nil } @@ -324,6 +326,10 @@ func (c *mcuProxyConnection) Load() int64 { return atomic.LoadInt64(&c.load) } +func (c *mcuProxyConnection) Country() string { + return c.country.Load().(string) +} + func (c *mcuProxyConnection) IsShutdownScheduled() bool { return atomic.LoadUint32(&c.shutdownScheduled) != 0 } @@ -564,7 +570,19 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { c.scheduleReconnect() case "hello": c.sessionId = msg.Hello.SessionId - log.Printf("Received session %s from %s", c.sessionId, c.url) + country := "" + if msg.Hello.Server != nil { + if country = msg.Hello.Server.Country; country != "" && !IsValidCountry(country) { + log.Printf("Proxy %s sent invalid country %s in hello response", c.url, country) + country = "" + } + } + c.country.Store(country) + if country != "" { + log.Printf("Received session %s from %s (in %s)", c.sessionId, c.url, country) + } else { + log.Printf("Received session %s from %s", c.sessionId, c.url) + } default: log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c.url) c.scheduleReconnect() @@ -930,7 +948,56 @@ func (l mcuProxyConnectionsList) Sort() { sort.Sort(l) } -func (m *mcuProxy) getSortedConnections() []*mcuProxyConnection { +func ContinentsOverlap(a, b []string) bool { + if len(a) == 0 || len(b) == 0 { + return false + } + + for _, checkA := range a { + for _, checkB := range b { + if checkA == checkB { + return true + } + } + } + return false +} + +func sortConnectionsForCountry(connections []*mcuProxyConnection, country string) []*mcuProxyConnection { + // Move connections in the same country to the start of the list. + sorted := make(mcuProxyConnectionsList, 0, len(connections)) + unprocessed := make(mcuProxyConnectionsList, 0, len(connections)) + for _, conn := range connections { + if country == conn.Country() { + sorted = append(sorted, conn) + } else { + unprocessed = append(unprocessed, conn) + } + } + if continents, found := ContinentMap[country]; found && len(unprocessed) > 1 { + remaining := make(mcuProxyConnectionsList, 0, len(unprocessed)) + // Next up are connections on the same continent. + for _, conn := range unprocessed { + connCountry := conn.Country() + if IsValidCountry(connCountry) { + connContinents := ContinentMap[connCountry] + if ContinentsOverlap(continents, connContinents) { + sorted = append(sorted, conn) + } else { + remaining = append(remaining, conn) + } + } else { + remaining = append(remaining, conn) + } + } + unprocessed = remaining + } + // Add all other connections by load. + sorted = append(sorted, unprocessed...) + return sorted +} + +func (m *mcuProxy) getSortedConnections(initiator McuInitiator) []*mcuProxyConnection { connections := m.getConnections() if len(connections) < 2 { return connections @@ -951,6 +1018,11 @@ func (m *mcuProxy) getSortedConnections() []*mcuProxyConnection { connections = sorted } + if initiator != nil { + if country := initiator.Country(); IsValidCountry(country) { + connections = sortConnectionsForCountry(connections, country) + } + } return connections } @@ -980,8 +1052,8 @@ func (m *mcuProxy) removeWaiter(id uint64) { delete(m.publisherWaiters, id) } -func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { - connections := m.getSortedConnections() +func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) { + connections := m.getSortedConnections(initiator) for _, conn := range connections { if conn.IsShutdownScheduled() { continue diff --git a/src/signaling/mcu_proxy_test.go b/src/signaling/mcu_proxy_test.go new file mode 100644 index 0000000..bf1cbca --- /dev/null +++ b/src/signaling/mcu_proxy_test.go @@ -0,0 +1,86 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2020 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "testing" +) + +func newProxyConnectionWithCountry(country string) *mcuProxyConnection { + conn := &mcuProxyConnection{} + conn.country.Store(country) + return conn +} + +func Test_sortConnectionsForCountry(t *testing.T) { + conn_de := newProxyConnectionWithCountry("DE") + conn_at := newProxyConnectionWithCountry("AT") + conn_jp := newProxyConnectionWithCountry("JP") + conn_us := newProxyConnectionWithCountry("US") + + testcases := map[string][][]*mcuProxyConnection{ + // Direct country match + "DE": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_at, conn_jp, conn_de}, + []*mcuProxyConnection{conn_de, conn_at, conn_jp}, + }, + // Direct country match + "AT": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_at, conn_jp, conn_de}, + []*mcuProxyConnection{conn_at, conn_de, conn_jp}, + }, + // Continent match + "CH": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_de, conn_at, conn_jp}, + }, + // Direct country match + "JP": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_jp, conn_de, conn_at}, + }, + // Continent match + "CN": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_jp, conn_de, conn_at}, + }, + // Partial continent match + "RU": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_us, conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_de, conn_jp, conn_at, conn_us}, + }, + // No match + "AU": [][]*mcuProxyConnection{ + []*mcuProxyConnection{conn_us, conn_de, conn_jp, conn_at}, + []*mcuProxyConnection{conn_us, conn_de, conn_jp, conn_at}, + }, + } + + for country, test := range testcases { + sorted := sortConnectionsForCountry(test[0], country) + for idx, conn := range sorted { + if test[1][idx] != conn { + t.Errorf("Index %d for %s: expected %s, got %s", idx, country, test[1][idx].Country(), conn.Country()) + } + } + } +} diff --git a/src/signaling/mcu_test.go b/src/signaling/mcu_test.go index 7062a49..dbfe485 100644 --- a/src/signaling/mcu_test.go +++ b/src/signaling/mcu_test.go @@ -51,7 +51,7 @@ func (m *TestMCU) GetStats() interface{} { return nil } -func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) { +func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) { return nil, fmt.Errorf("Not implemented") }