nextcloud-spreed-signaling/mcu_proxy.go

/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2020 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling

import (
	"context"
	"crypto/rsa"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/dlintw/goconf"
	"github.com/gorilla/websocket"
	"go.etcd.io/etcd/clientv3"
	"go.etcd.io/etcd/pkg/srv"
	"go.etcd.io/etcd/pkg/transport"
	"gopkg.in/dgrijalva/jwt-go.v3"
)
const (
closeTimeout = time.Second
proxyDebugMessages = false
// Very high value so the connections get sorted at the end.
loadNotConnected = 1000000
// Sort connections by load every 10 publishing requests or once per second.
connectionSortRequests = 10
connectionSortInterval = time.Second
proxyUrlTypeStatic = "static"
proxyUrlTypeEtcd = "etcd"
initialWaitDelay = time.Second
maxWaitDelay = 8 * time.Second
defaultProxyTimeoutSeconds = 2
)
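// mcuProxyPubSubCommon contains the state shared by proxy publishers and
// subscribers: the stream type, the client id assigned by the proxy server
// ("proxyId"), the connection the object was created on and the listener
// that gets notified about events.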
type mcuProxyPubSubCommon struct {
streamType string
proxyId string
conn *mcuProxyConnection
listener McuListener
}
func (c *mcuProxyPubSubCommon) Id() string {
return c.proxyId
}
func (c *mcuProxyPubSubCommon) StreamType() string {
return c.streamType
}
func (c *mcuProxyPubSubCommon) doSendMessage(ctx context.Context, msg *ProxyClientMessage, callback func(error, map[string]interface{})) {
c.conn.performAsyncRequest(ctx, msg, func(err error, response *ProxyServerMessage) {
if err != nil {
callback(err, nil)
return
}
if proxyDebugMessages {
log.Printf("Response from %s: %+v", c.conn.url, response)
}
if response.Type == "error" {
callback(response.Error, nil)
} else if response.Payload != nil {
callback(nil, response.Payload.Payload)
} else {
callback(nil, nil)
}
})
}
func (c *mcuProxyPubSubCommon) doProcessPayload(client McuClient, msg *PayloadProxyServerMessage) {
switch msg.Type {
case "candidate":
c.listener.OnIceCandidate(client, msg.Payload["candidate"])
default:
log.Printf("Unsupported payload from %s: %+v", c.conn.url, msg)
}
}
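// mcuProxyPublisher is a publisher that was created on a remote proxy
// server. The "id" is the id the publisher was registered with locally.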
type mcuProxyPublisher struct {
mcuProxyPubSubCommon
id string
}
func newMcuProxyPublisher(id string, streamType string, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher {
return &mcuProxyPublisher{
mcuProxyPubSubCommon: mcuProxyPubSubCommon{
streamType: streamType,
proxyId: proxyId,
conn: conn,
listener: listener,
},
id: id,
}
}
func (p *mcuProxyPublisher) NotifyClosed() {
p.listener.PublisherClosed(p)
p.conn.removePublisher(p)
}
func (p *mcuProxyPublisher) Close(ctx context.Context) {
p.NotifyClosed()
msg := &ProxyClientMessage{
Type: "command",
Command: &CommandProxyClientMessage{
Type: "delete-publisher",
ClientId: p.proxyId,
},
}
if _, err := p.conn.performSyncRequest(ctx, msg); err != nil {
log.Printf("Could not delete publisher %s at %s: %s", p.proxyId, p.conn.url, err)
return
}
log.Printf("Delete publisher %s at %s", p.proxyId, p.conn.url)
}
func (p *mcuProxyPublisher) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) {
msg := &ProxyClientMessage{
Type: "payload",
Payload: &PayloadProxyClientMessage{
Type: data.Type,
ClientId: p.proxyId,
Payload: data.Payload,
},
}
p.doSendMessage(ctx, msg, callback)
}
func (p *mcuProxyPublisher) ProcessPayload(msg *PayloadProxyServerMessage) {
p.doProcessPayload(p, msg)
}
func (p *mcuProxyPublisher) ProcessEvent(msg *EventProxyServerMessage) {
switch msg.Type {
case "ice-completed":
p.listener.OnIceCompleted(p)
case "publisher-closed":
p.NotifyClosed()
default:
log.Printf("Unsupported event from %s: %+v", p.conn.url, msg)
}
}
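// mcuProxySubscriber is a subscriber to the stream of "publisherId" that was
// created on a remote proxy server.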
type mcuProxySubscriber struct {
mcuProxyPubSubCommon
publisherId string
}
func newMcuProxySubscriber(publisherId string, streamType string, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxySubscriber {
return &mcuProxySubscriber{
mcuProxyPubSubCommon: mcuProxyPubSubCommon{
streamType: streamType,
proxyId: proxyId,
conn: conn,
listener: listener,
},
publisherId: publisherId,
}
}
func (s *mcuProxySubscriber) Publisher() string {
return s.publisherId
}
func (s *mcuProxySubscriber) NotifyClosed() {
s.listener.SubscriberClosed(s)
s.conn.removeSubscriber(s)
}
func (s *mcuProxySubscriber) Close(ctx context.Context) {
s.NotifyClosed()
msg := &ProxyClientMessage{
Type: "command",
Command: &CommandProxyClientMessage{
Type: "delete-subscriber",
ClientId: s.proxyId,
},
}
if _, err := s.conn.performSyncRequest(ctx, msg); err != nil {
log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn.url, err)
return
}
log.Printf("Delete subscriber %s at %s", s.proxyId, s.conn.url)
}
func (s *mcuProxySubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) {
msg := &ProxyClientMessage{
Type: "payload",
Payload: &PayloadProxyClientMessage{
Type: data.Type,
ClientId: s.proxyId,
Payload: data.Payload,
},
}
s.doSendMessage(ctx, msg, callback)
}
func (s *mcuProxySubscriber) ProcessPayload(msg *PayloadProxyServerMessage) {
s.doProcessPayload(s, msg)
}
func (s *mcuProxySubscriber) ProcessEvent(msg *EventProxyServerMessage) {
switch msg.Type {
case "ice-completed":
s.listener.OnIceCompleted(s)
case "subscriber-closed":
s.NotifyClosed()
default:
log.Printf("Unsupported event from %s: %+v", s.conn.url, msg)
}
}
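// mcuProxyConnection manages the websocket connection to a single proxy
// server: connecting and reconnecting with exponential backoff, periodic
// pings, callbacks for pending requests and the publishers and subscribers
// that have been created on this proxy.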
type mcuProxyConnection struct {
// 64-bit members that are accessed atomically must be 64-bit aligned.
reconnectInterval int64
msgId int64
load int64
proxy *mcuProxy
rawUrl string
url *url.URL
mu sync.Mutex
closeChan chan bool
closedChan chan bool
closed uint32
conn *websocket.Conn
connectedSince time.Time
reconnectTimer *time.Timer
shutdownScheduled uint32
closeScheduled uint32
helloMsgId string
sessionId string
country atomic.Value
callbacks map[string]func(*ProxyServerMessage)
publishersLock sync.RWMutex
publishers map[string]*mcuProxyPublisher
publisherIds map[string]string
subscribersLock sync.RWMutex
subscribers map[string]*mcuProxySubscriber
}
func newMcuProxyConnection(proxy *mcuProxy, baseUrl string) (*mcuProxyConnection, error) {
parsed, err := url.Parse(baseUrl)
if err != nil {
return nil, err
}
conn := &mcuProxyConnection{
proxy: proxy,
rawUrl: baseUrl,
url: parsed,
closeChan: make(chan bool, 1),
closedChan: make(chan bool, 1),
reconnectInterval: int64(initialReconnectInterval),
load: loadNotConnected,
callbacks: make(map[string]func(*ProxyServerMessage)),
publishers: make(map[string]*mcuProxyPublisher),
publisherIds: make(map[string]string),
subscribers: make(map[string]*mcuProxySubscriber),
}
conn.country.Store("")
return conn, nil
}
type mcuProxyConnectionStats struct {
Url string `json:"url"`
Connected bool `json:"connected"`
Publishers int64 `json:"publishers"`
Clients int64 `json:"clients"`
Load *int64 `json:"load,omitempty"`
Shutdown *bool `json:"shutdown,omitempty"`
Uptime *time.Time `json:"uptime,omitempty"`
}
func (c *mcuProxyConnection) GetStats() *mcuProxyConnectionStats {
result := &mcuProxyConnectionStats{
Url: c.url.String(),
}
c.mu.Lock()
if c.conn != nil {
result.Connected = true
result.Uptime = &c.connectedSince
load := c.Load()
result.Load = &load
shutdown := c.IsShutdownScheduled()
result.Shutdown = &shutdown
}
c.mu.Unlock()
c.publishersLock.RLock()
result.Publishers = int64(len(c.publishers))
c.publishersLock.RUnlock()
c.subscribersLock.RLock()
result.Clients = int64(len(c.subscribers))
c.subscribersLock.RUnlock()
result.Clients += result.Publishers
return result
}
func (c *mcuProxyConnection) Load() int64 {
return atomic.LoadInt64(&c.load)
}
func (c *mcuProxyConnection) Country() string {
return c.country.Load().(string)
}
func (c *mcuProxyConnection) IsShutdownScheduled() bool {
return atomic.LoadUint32(&c.shutdownScheduled) != 0 || atomic.LoadUint32(&c.closeScheduled) != 0
}
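// readPump reads messages from the proxy until the connection fails or is
// closed. The pong handler logs the round-trip time based on the timestamp
// that was sent with the ping. Unless the connection was closed on purpose,
// a reconnect is scheduled when the loop terminates.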
func (c *mcuProxyConnection) readPump() {
defer func() {
if atomic.LoadUint32(&c.closed) == 0 {
c.scheduleReconnect()
} else {
c.closedChan <- true
}
}()
defer c.close()
defer atomic.StoreInt64(&c.load, loadNotConnected)
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
conn.SetPongHandler(func(msg string) error {
now := time.Now()
conn.SetReadDeadline(now.Add(pongWait))
if msg == "" {
return nil
}
if ts, err := strconv.ParseInt(msg, 10, 64); err == nil {
rtt := now.Sub(time.Unix(0, ts))
rtt_ms := rtt.Nanoseconds() / time.Millisecond.Nanoseconds()
log.Printf("Proxy at %s has RTT of %d ms (%s)", c.url, rtt_ms, rtt)
}
return nil
})
for {
conn.SetReadDeadline(time.Now().Add(pongWait))
_, message, err := conn.ReadMessage()
if err != nil {
if _, ok := err.(*websocket.CloseError); !ok || websocket.IsUnexpectedCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived) {
log.Printf("Error reading from %s: %v", c.url, err)
}
break
}
var msg ProxyServerMessage
if err := json.Unmarshal(message, &msg); err != nil {
log.Printf("Error unmarshaling %s from %s: %s", string(message), c.url, err)
continue
}
c.processMessage(&msg)
}
}
func (c *mcuProxyConnection) sendPing() bool {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
return false
}
now := time.Now()
msg := strconv.FormatInt(now.UnixNano(), 10)
c.conn.SetWriteDeadline(now.Add(writeWait))
if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil {
log.Printf("Could not send ping to proxy at %s: %v", c.url, err)
c.scheduleReconnect()
return false
}
return true
}
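// writePump sends periodic pings and performs (re)connect attempts whenever
// the reconnect timer fires. The timer is initially armed with a zero
// duration so the first connection attempt happens immediately.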
func (c *mcuProxyConnection) writePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
}()
c.reconnectTimer = time.NewTimer(0)
for {
select {
case <-c.reconnectTimer.C:
c.reconnect()
case <-ticker.C:
c.sendPing()
case <-c.closeChan:
return
}
}
}
func (c *mcuProxyConnection) start() error {
go c.writePump()
return nil
}
func (c *mcuProxyConnection) sendClose() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
return ErrNotConnected
}
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
return c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
}
func (c *mcuProxyConnection) stop(ctx context.Context) {
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return
}
c.closeChan <- true
if err := c.sendClose(); err != nil {
if err != ErrNotConnected {
log.Printf("Could not send close message to %s: %s", c.url, err)
}
c.close()
return
}
select {
case <-c.closedChan:
case <-ctx.Done():
if err := ctx.Err(); err != nil {
log.Printf("Error waiting for connection to %s get closed: %s", c.url, err)
c.close()
}
}
}
func (c *mcuProxyConnection) close() {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
}
func (c *mcuProxyConnection) stopCloseIfEmpty() {
atomic.StoreUint32(&c.closeScheduled, 0)
}
func (c *mcuProxyConnection) closeIfEmpty() bool {
atomic.StoreUint32(&c.closeScheduled, 1)
var total int64
c.publishersLock.RLock()
total += int64(len(c.publishers))
c.publishersLock.RUnlock()
c.subscribersLock.RLock()
total += int64(len(c.subscribers))
c.subscribersLock.RUnlock()
if total > 0 {
// Connection will be closed once all clients have disconnected.
log.Printf("Connection to %s is still used by %d clients, defer closing", c.url, total)
return false
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), closeTimeout)
defer cancel()
log.Printf("All clients disconnected, closing connection to %s", c.url)
c.stop(ctx)
c.proxy.removeConnection(c)
}()
return true
}
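// scheduleReconnect tries to close the current connection cleanly and arms
// the reconnect timer with the current backoff interval. The interval is
// doubled for the next attempt, capped at maxReconnectInterval.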
func (c *mcuProxyConnection) scheduleReconnect() {
if err := c.sendClose(); err != nil && err != ErrNotConnected {
log.Printf("Could not send close message to %s: %s", c.url, err)
c.close()
}
interval := atomic.LoadInt64(&c.reconnectInterval)
c.reconnectTimer.Reset(time.Duration(interval))
interval = interval * 2
if interval > int64(maxReconnectInterval) {
interval = int64(maxReconnectInterval)
}
atomic.StoreInt64(&c.reconnectInterval, interval)
}
func (c *mcuProxyConnection) reconnect() {
u, err := c.url.Parse("proxy")
if err != nil {
log.Printf("Could not resolve url to proxy at %s: %s", c.url, err)
c.scheduleReconnect()
return
}
if u.Scheme == "http" {
u.Scheme = "ws"
} else if u.Scheme == "https" {
u.Scheme = "wss"
}
conn, _, err := c.proxy.dialer.Dial(u.String(), nil)
if err != nil {
log.Printf("Could not connect to %s: %s", u, err)
c.scheduleReconnect()
return
}
log.Printf("Connected to %s", u)
atomic.StoreUint32(&c.closed, 0)
c.mu.Lock()
c.connectedSince = time.Now()
c.conn = conn
c.mu.Unlock()
atomic.StoreInt64(&c.reconnectInterval, int64(initialReconnectInterval))
atomic.StoreUint32(&c.shutdownScheduled, 0)
if err := c.sendHello(); err != nil {
log.Printf("Could not send hello request to %s: %s", c.url, err)
c.scheduleReconnect()
return
}
if !c.sendPing() {
return
}
go c.readPump()
}
func (c *mcuProxyConnection) removePublisher(publisher *mcuProxyPublisher) {
c.proxy.removePublisher(publisher)
c.publishersLock.Lock()
defer c.publishersLock.Unlock()
delete(c.publishers, publisher.proxyId)
delete(c.publisherIds, publisher.id+"|"+publisher.StreamType())
if len(c.publishers) == 0 && atomic.LoadUint32(&c.closeScheduled) != 0 {
go c.closeIfEmpty()
}
}
func (c *mcuProxyConnection) clearPublishers() {
c.publishersLock.Lock()
defer c.publishersLock.Unlock()
go func(publishers map[string]*mcuProxyPublisher) {
for _, publisher := range publishers {
publisher.NotifyClosed()
}
}(c.publishers)
c.publishers = make(map[string]*mcuProxyPublisher)
c.publisherIds = make(map[string]string)
if atomic.LoadUint32(&c.closeScheduled) != 0 {
go c.closeIfEmpty()
}
}
func (c *mcuProxyConnection) removeSubscriber(subscriber *mcuProxySubscriber) {
c.subscribersLock.Lock()
defer c.subscribersLock.Unlock()
delete(c.subscribers, subscriber.proxyId)
if len(c.subscribers) == 0 && atomic.LoadUint32(&c.closeScheduled) != 0 {
go c.closeIfEmpty()
}
}
func (c *mcuProxyConnection) clearSubscribers() {
c.subscribersLock.Lock()
defer c.subscribersLock.Unlock()
go func(subscribers map[string]*mcuProxySubscriber) {
for _, subscriber := range subscribers {
subscriber.NotifyClosed()
}
}(c.subscribers)
c.subscribers = make(map[string]*mcuProxySubscriber)
if atomic.LoadUint32(&c.closeScheduled) != 0 {
go c.closeIfEmpty()
}
}
func (c *mcuProxyConnection) clearCallbacks() {
c.mu.Lock()
defer c.mu.Unlock()
c.callbacks = make(map[string]func(*ProxyServerMessage))
}
func (c *mcuProxyConnection) getCallback(id string) func(*ProxyServerMessage) {
c.mu.Lock()
defer c.mu.Unlock()
callback, found := c.callbacks[id]
if found {
delete(c.callbacks, id)
}
return callback
}
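// processMessage dispatches a message received from the proxy. The response
// to the "hello" request is handled first, then responses to pending
// requests (matched by message id) and finally asynchronous "payload",
// "event" and "bye" messages.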
func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) {
if c.helloMsgId != "" && msg.Id == c.helloMsgId {
c.helloMsgId = ""
switch msg.Type {
case "error":
if msg.Error.Code == "no_such_session" {
log.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c.url)
c.clearPublishers()
c.clearSubscribers()
c.clearCallbacks()
c.sessionId = ""
if err := c.sendHello(); err != nil {
log.Printf("Could not send hello request to %s: %s", c.url, err)
c.scheduleReconnect()
}
return
}
log.Printf("Hello connection to %s failed with %+v, reconnecting", c.url, msg.Error)
c.scheduleReconnect()
case "hello":
resumed := c.sessionId == msg.Hello.SessionId
c.sessionId = msg.Hello.SessionId
country := ""
if msg.Hello.Server != nil {
if country = msg.Hello.Server.Country; country != "" && !IsValidCountry(country) {
log.Printf("Proxy %s sent invalid country %s in hello response", c.url, country)
country = ""
}
}
c.country.Store(country)
if resumed {
log.Printf("Resumed session %s on %s", c.sessionId, c.url)
} else if country != "" {
log.Printf("Received session %s from %s (in %s)", c.sessionId, c.url, country)
} else {
log.Printf("Received session %s from %s", c.sessionId, c.url)
}
default:
log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c.url)
c.scheduleReconnect()
}
return
}
if proxyDebugMessages {
log.Printf("Received from %s: %+v", c.url, msg)
}
callback := c.getCallback(msg.Id)
if callback != nil {
callback(msg)
return
}
switch msg.Type {
case "payload":
c.processPayload(msg)
case "event":
c.processEvent(msg)
case "bye":
c.processBye(msg)
default:
log.Printf("Unsupported message received from %s: %+v", c.url, msg)
}
}
func (c *mcuProxyConnection) processPayload(msg *ProxyServerMessage) {
payload := msg.Payload
c.publishersLock.RLock()
publisher, found := c.publishers[payload.ClientId]
c.publishersLock.RUnlock()
if found {
publisher.ProcessPayload(payload)
return
}
c.subscribersLock.RLock()
subscriber, found := c.subscribers[payload.ClientId]
c.subscribersLock.RUnlock()
if found {
subscriber.ProcessPayload(payload)
return
}
log.Printf("Received payload for unknown client %+v from %s", payload, c.url)
}
func (c *mcuProxyConnection) processEvent(msg *ProxyServerMessage) {
event := msg.Event
switch event.Type {
case "backend-disconnected":
log.Printf("Upstream backend at %s got disconnected, reset MCU objects", c.url)
c.clearPublishers()
c.clearSubscribers()
c.clearCallbacks()
// TODO: Should we also reconnect?
return
case "backend-connected":
log.Printf("Upstream backend at %s is connected", c.url)
return
case "update-load":
if proxyDebugMessages {
log.Printf("Load of %s now at %d", c.url, event.Load)
}
atomic.StoreInt64(&c.load, event.Load)
return
case "shutdown-scheduled":
log.Printf("Proxy %s is scheduled to shutdown", c.url)
atomic.StoreUint32(&c.shutdownScheduled, 1)
return
}
if proxyDebugMessages {
log.Printf("Process event from %s: %+v", c.url, event)
}
c.publishersLock.RLock()
publisher, found := c.publishers[event.ClientId]
c.publishersLock.RUnlock()
if found {
publisher.ProcessEvent(event)
return
}
c.subscribersLock.RLock()
subscriber, found := c.subscribers[event.ClientId]
c.subscribersLock.RUnlock()
if found {
subscriber.ProcessEvent(event)
return
}
log.Printf("Received event for unknown client %+v from %s", event, c.url)
}
func (c *mcuProxyConnection) processBye(msg *ProxyServerMessage) {
bye := msg.Bye
switch bye.Reason {
case "session_resumed":
log.Printf("Session %s on %s was resumed by other client, resetting", c.sessionId, c.url)
c.sessionId = ""
default:
log.Printf("Received bye with unsupported reason from %s %+v", c.url, bye)
}
}
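// sendHello sends the initial "hello" request. If a session id from a
// previous connection is known, the session is resumed, otherwise a JWT
// signed with the configured RSA private key is sent to authenticate against
// the proxy.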
func (c *mcuProxyConnection) sendHello() error {
c.helloMsgId = strconv.FormatInt(atomic.AddInt64(&c.msgId, 1), 10)
msg := &ProxyClientMessage{
Id: c.helloMsgId,
Type: "hello",
Hello: &HelloProxyClientMessage{
Version: "1.0",
},
}
if c.sessionId != "" {
msg.Hello.ResumeId = c.sessionId
} else {
claims := &TokenClaims{
jwt.StandardClaims{
IssuedAt: time.Now().Unix(),
Issuer: c.proxy.tokenId,
},
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenString, err := token.SignedString(c.proxy.tokenKey)
if err != nil {
return err
}
msg.Hello.Token = tokenString
}
return c.sendMessage(msg)
}
func (c *mcuProxyConnection) sendMessage(msg *ProxyClientMessage) error {
c.mu.Lock()
defer c.mu.Unlock()
return c.sendMessageLocked(msg)
}
func (c *mcuProxyConnection) sendMessageLocked(msg *ProxyClientMessage) error {
if proxyDebugMessages {
log.Printf("Send message to %s: %+v", c.url, msg)
}
if c.conn == nil {
return ErrNotConnected
}
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
return c.conn.WriteJSON(msg)
}
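// performAsyncRequest assigns a message id to the request, sends it and
// invokes the callback once the response with the same id arrives.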
func (c *mcuProxyConnection) performAsyncRequest(ctx context.Context, msg *ProxyClientMessage, callback func(err error, response *ProxyServerMessage)) {
msgId := strconv.FormatInt(atomic.AddInt64(&c.msgId, 1), 10)
msg.Id = msgId
c.mu.Lock()
defer c.mu.Unlock()
c.callbacks[msgId] = func(msg *ProxyServerMessage) {
callback(nil, msg)
}
if err := c.sendMessageLocked(msg); err != nil {
delete(c.callbacks, msgId)
go callback(err, nil)
return
}
}
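// performSyncRequest sends a request and waits for either the response, a
// send error or the cancellation of the passed context.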
func (c *mcuProxyConnection) performSyncRequest(ctx context.Context, msg *ProxyClientMessage) (*ProxyServerMessage, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
errChan := make(chan error, 1)
responseChan := make(chan *ProxyServerMessage, 1)
c.performAsyncRequest(ctx, msg, func(err error, response *ProxyServerMessage) {
if err != nil {
errChan <- err
} else {
responseChan <- response
}
})
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errChan:
return nil, err
case response := <-responseChan:
return response, nil
}
}
func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListener, id string, streamType string) (McuPublisher, error) {
msg := &ProxyClientMessage{
Type: "command",
Command: &CommandProxyClientMessage{
Type: "create-publisher",
StreamType: streamType,
},
}
response, err := c.performSyncRequest(ctx, msg)
if err != nil {
// TODO: Cancel request
return nil, err
}
proxyId := response.Command.Id
log.Printf("Created %s publisher %s on %s for %s", streamType, proxyId, c.url, id)
publisher := newMcuProxyPublisher(id, streamType, proxyId, c, listener)
c.publishersLock.Lock()
c.publishers[proxyId] = publisher
c.publisherIds[id+"|"+streamType] = proxyId
c.publishersLock.Unlock()
return publisher, nil
}
func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) {
c.publishersLock.Lock()
id, found := c.publisherIds[publisher+"|"+streamType]
c.publishersLock.Unlock()
if !found {
return nil, fmt.Errorf("Unknown publisher %s", publisher)
}
msg := &ProxyClientMessage{
Type: "command",
Command: &CommandProxyClientMessage{
Type: "create-subscriber",
StreamType: streamType,
PublisherId: id,
},
}
response, err := c.performSyncRequest(ctx, msg)
if err != nil {
// TODO: Cancel request
return nil, err
}
proxyId := response.Command.Id
log.Printf("Created %s subscriber %s on %s for %s", streamType, proxyId, c.url, publisher)
subscriber := newMcuProxySubscriber(publisher, streamType, proxyId, c, listener)
c.subscribersLock.Lock()
c.subscribers[proxyId] = subscriber
c.subscribersLock.Unlock()
return subscriber, nil
}
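// mcuProxy implements the Mcu interface on top of a set of proxy server
// connections. Publishers are created on the connection with the lowest
// load, subscribers are created on the connection that hosts the publisher.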
type mcuProxy struct {
// 64-bit members that are accessed atomically must be 64-bit aligned.
connRequests int64
nextSort int64
tokenId string
tokenKey *rsa.PrivateKey
etcdMu sync.Mutex
client atomic.Value
keyPrefix atomic.Value
keyInfos map[string]*ProxyInformationEtcd
urlToKey map[string]string
dialer *websocket.Dialer
connections []*mcuProxyConnection
connectionsMap map[string]*mcuProxyConnection
connectionsMu sync.RWMutex
proxyTimeout time.Duration
mu sync.RWMutex
publishers map[string]*mcuProxyConnection
publisherWaitersId uint64
publisherWaiters map[uint64]chan bool
}
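// NewMcuProxy creates a new Mcu instance that delegates publishers and
// subscribers to one or more standalone proxy servers. The proxy URLs are
// either configured statically or discovered through etcd ("urltype").
//
// Illustrative example for a static setup (option names are taken from the
// code below, the values are placeholders):
//
//   [mcu]
//   urltype = static
//   url = https://proxy1.example.invalid https://proxy2.example.invalid
//   token_id = server1
//   token_key = /etc/signaling/token.key
//   proxytimeout = 2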
func NewMcuProxy(config *goconf.ConfigFile) (Mcu, error) {
urlType, _ := config.GetString("mcu", "urltype")
tokenId, _ := config.GetString("mcu", "token_id")
if tokenId == "" {
return nil, fmt.Errorf("No token id configured")
}
tokenKeyFilename, _ := config.GetString("mcu", "token_key")
if tokenKeyFilename == "" {
return nil, fmt.Errorf("No token key configured")
}
tokenKeyData, err := ioutil.ReadFile(tokenKeyFilename)
if err != nil {
return nil, fmt.Errorf("Could not read private key from %s: %s", tokenKeyFilename, err)
}
tokenKey, err := jwt.ParseRSAPrivateKeyFromPEM(tokenKeyData)
if err != nil {
return nil, fmt.Errorf("Could not parse private key from %s: %s", tokenKeyFilename, err)
}
proxyTimeoutSeconds, _ := config.GetInt("mcu", "proxytimeout")
if proxyTimeoutSeconds <= 0 {
proxyTimeoutSeconds = defaultProxyTimeoutSeconds
}
proxyTimeout := time.Duration(proxyTimeoutSeconds) * time.Second
log.Printf("Using a timeout of %s for proxy requests", proxyTimeout)
mcu := &mcuProxy{
tokenId: tokenId,
tokenKey: tokenKey,
dialer: &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: proxyTimeout,
},
connectionsMap: make(map[string]*mcuProxyConnection),
proxyTimeout: proxyTimeout,
publishers: make(map[string]*mcuProxyConnection),
publisherWaiters: make(map[uint64]chan bool),
}
skipverify, _ := config.GetBool("mcu", "skipverify")
if skipverify {
log.Println("WARNING: MCU verification is disabled!")
mcu.dialer.TLSClientConfig = &tls.Config{
InsecureSkipVerify: skipverify,
}
}
if urlType == "" {
urlType = proxyUrlTypeStatic
}
switch urlType {
case proxyUrlTypeStatic:
mcuUrl, _ := config.GetString("mcu", "url")
for _, u := range strings.Split(mcuUrl, " ") {
conn, err := newMcuProxyConnection(mcu, u)
if err != nil {
return nil, err
}
mcu.connections = append(mcu.connections, conn)
mcu.connectionsMap[u] = conn
}
if len(mcu.connections) == 0 {
return nil, fmt.Errorf("No MCU proxy connections configured")
}
case proxyUrlTypeEtcd:
mcu.keyInfos = make(map[string]*ProxyInformationEtcd)
mcu.urlToKey = make(map[string]string)
if err := mcu.configureEtcd(config, false); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("Unsupported proxy URL type %s", urlType)
}
return mcu, nil
}
func (m *mcuProxy) getEtcdClient() *clientv3.Client {
c := m.client.Load()
if c == nil {
return nil
}
return c.(*clientv3.Client)
}
func (m *mcuProxy) Start() error {
m.connectionsMu.RLock()
defer m.connectionsMu.RUnlock()
for _, c := range m.connections {
if err := c.start(); err != nil {
return err
}
}
return nil
}
func (m *mcuProxy) Stop() {
m.connectionsMu.RLock()
defer m.connectionsMu.RUnlock()
for _, c := range m.connections {
ctx, cancel := context.WithTimeout(context.Background(), closeTimeout)
defer cancel()
c.stop(ctx)
}
}
func (m *mcuProxy) configureEtcd(config *goconf.ConfigFile, ignoreErrors bool) error {
keyPrefix, _ := config.GetString("mcu", "keyprefix")
if keyPrefix == "" {
keyPrefix = "/%s"
}
var endpoints []string
if endpointsString, _ := config.GetString("mcu", "endpoints"); endpointsString != "" {
for _, ep := range strings.Split(endpointsString, ",") {
ep := strings.TrimSpace(ep)
if ep != "" {
endpoints = append(endpoints, ep)
}
}
} else if discoverySrv, _ := config.GetString("mcu", "discoverysrv"); discoverySrv != "" {
discoveryService, _ := config.GetString("mcu", "discoveryservice")
clients, err := srv.GetClient("etcd-client", discoverySrv, discoveryService)
if err != nil {
if !ignoreErrors {
return fmt.Errorf("Could not discover endpoints for %s: %s", discoverySrv, err)
}
} else {
endpoints = clients.Endpoints
}
}
if len(endpoints) == 0 {
if !ignoreErrors {
return fmt.Errorf("No proxy URL endpoints configured")
}
log.Printf("No proxy URL endpoints configured, not changing client")
} else {
cfg := clientv3.Config{
Endpoints: endpoints,
// set timeout per request to fail fast when the target endpoint is unavailable
DialTimeout: time.Second,
}
clientKey, _ := config.GetString("mcu", "clientkey")
clientCert, _ := config.GetString("mcu", "clientcert")
caCert, _ := config.GetString("mcu", "cacert")
if clientKey != "" && clientCert != "" && caCert != "" {
tlsInfo := transport.TLSInfo{
CertFile: clientCert,
KeyFile: clientKey,
TrustedCAFile: caCert,
}
tlsConfig, err := tlsInfo.ClientConfig()
if err != nil {
if !ignoreErrors {
return fmt.Errorf("Could not setup TLS configuration: %s", err)
}
log.Printf("Could not setup TLS configuration, will be disabled (%s)", err)
} else {
cfg.TLS = tlsConfig
}
}
c, err := clientv3.New(cfg)
if err != nil {
if !ignoreErrors {
return err
}
log.Printf("Could not create new client from proxy URL endpoints %+v: %s", endpoints, err)
} else {
prev := m.getEtcdClient()
if prev != nil {
prev.Close()
}
m.client.Store(c)
log.Printf("Using proxy URL endpoints %+v", endpoints)
go func(client *clientv3.Client) {
log.Printf("Wait for leader and start watching on %s", keyPrefix)
ch := client.Watch(clientv3.WithRequireLeader(context.Background()), keyPrefix, clientv3.WithPrefix())
log.Printf("Watch created for %s", keyPrefix)
m.processWatches(ch)
}(c)
go func() {
m.waitForConnection()
waitDelay := initialWaitDelay
for {
response, err := m.getProxyUrls(keyPrefix)
if err != nil {
if err == context.DeadlineExceeded {
log.Printf("Timeout getting initial list of proxy URLs, retry in %s", waitDelay)
} else {
log.Printf("Could not get initial list of proxy URLs, retry in %s: %s", waitDelay, err)
}
time.Sleep(waitDelay)
waitDelay = waitDelay * 2
if waitDelay > maxWaitDelay {
waitDelay = maxWaitDelay
}
continue
}
for _, ev := range response.Kvs {
m.addEtcdProxy(string(ev.Key), ev.Value)
}
return
}
}()
}
}
return nil
}
func (m *mcuProxy) getProxyUrls(keyPrefix string) (*clientv3.GetResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
return m.getEtcdClient().Get(ctx, keyPrefix, clientv3.WithPrefix())
}
func (m *mcuProxy) waitForConnection() {
waitDelay := initialWaitDelay
for {
if err := m.syncClient(); err != nil {
if err == context.DeadlineExceeded {
log.Printf("Timeout waiting for etcd client to connect to the cluster, retry in %s", waitDelay)
} else {
log.Printf("Could not sync etcd client with the cluster, retry in %s: %s", waitDelay, err)
}
time.Sleep(waitDelay)
waitDelay = waitDelay * 2
if waitDelay > maxWaitDelay {
waitDelay = maxWaitDelay
}
continue
}
log.Printf("Client using endpoints %+v", m.getEtcdClient().Endpoints())
return
}
}
func (m *mcuProxy) syncClient() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
return m.getEtcdClient().Sync(ctx)
}
func (m *mcuProxy) Reload(config *goconf.ConfigFile) {
m.connectionsMu.Lock()
defer m.connectionsMu.Unlock()
remove := make(map[string]*mcuProxyConnection)
for u, conn := range m.connectionsMap {
remove[u] = conn
}
created := make(map[string]*mcuProxyConnection)
changed := false
mcuUrl, _ := config.GetString("mcu", "url")
for _, u := range strings.Split(mcuUrl, " ") {
if existing, found := remove[u]; found {
// Proxy connection still exists in new configuration
delete(remove, u)
existing.stopCloseIfEmpty()
continue
}
conn, err := newMcuProxyConnection(m, u)
if err != nil {
log.Printf("Could not create proxy connection to %s: %s", u, err)
continue
}
created[u] = conn
}
for _, conn := range remove {
go conn.closeIfEmpty()
}
for u, conn := range created {
if err := conn.start(); err != nil {
log.Printf("Could not start new connection to %s: %s", u, err)
continue
}
log.Printf("Adding new connection to %s", u)
m.connections = append(m.connections, conn)
m.connectionsMap[u] = conn
changed = true
}
if changed {
atomic.StoreInt64(&m.nextSort, 0)
}
}
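// processWatches applies updates from the etcd watch channel: added or
// changed keys below the configured prefix create proxy connections, deleted
// keys remove them.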
func (m *mcuProxy) processWatches(ch clientv3.WatchChan) {
for response := range ch {
for _, ev := range response.Events {
switch ev.Type {
case clientv3.EventTypePut:
m.addEtcdProxy(string(ev.Kv.Key), ev.Kv.Value)
case clientv3.EventTypeDelete:
m.removeEtcdProxy(string(ev.Kv.Key))
default:
log.Printf("Unsupported event %s %q -> %q", ev.Type, ev.Kv.Key, ev.Kv.Value)
}
}
}
}
func (m *mcuProxy) addEtcdProxy(key string, data []byte) {
var info ProxyInformationEtcd
if err := json.Unmarshal(data, &info); err != nil {
log.Printf("Could not decode proxy information %s: %s", string(data), err)
return
}
if err := info.CheckValid(); err != nil {
log.Printf("Received invalid proxy information %s: %s", string(data), err)
return
}
m.etcdMu.Lock()
defer m.etcdMu.Unlock()
prev, found := m.keyInfos[key]
if found && info.Address != prev.Address {
// Address of a proxy has changed.
m.removeEtcdProxyLocked(key)
}
if otherKey, found := m.urlToKey[info.Address]; found && otherKey != key {
log.Printf("Address %s is already registered for key %s, ignoring %s", info.Address, otherKey, key)
return
}
m.connectionsMu.Lock()
defer m.connectionsMu.Unlock()
if conn, found := m.connectionsMap[info.Address]; found {
m.keyInfos[key] = &info
m.urlToKey[info.Address] = key
conn.stopCloseIfEmpty()
} else {
conn, err := newMcuProxyConnection(m, info.Address)
if err != nil {
log.Printf("Could not create proxy connection to %s: %s", info.Address, err)
return
}
if err := conn.start(); err != nil {
log.Printf("Could not start new connection to %s: %s", info.Address, err)
return
}
log.Printf("Adding new connection to %s (from %s)", info.Address, key)
m.keyInfos[key] = &info
m.urlToKey[info.Address] = key
m.connections = append(m.connections, conn)
m.connectionsMap[info.Address] = conn
atomic.StoreInt64(&m.nextSort, 0)
}
}
func (m *mcuProxy) removeEtcdProxy(key string) {
m.etcdMu.Lock()
defer m.etcdMu.Unlock()
m.removeEtcdProxyLocked(key)
}
func (m *mcuProxy) removeEtcdProxyLocked(key string) {
info, found := m.keyInfos[key]
if !found {
return
}
delete(m.keyInfos, key)
delete(m.urlToKey, info.Address)
log.Printf("Removing connection to %s (from %s)", info.Address, key)
m.connectionsMu.RLock()
defer m.connectionsMu.RUnlock()
if conn, found := m.connectionsMap[info.Address]; found {
go conn.closeIfEmpty()
}
}
func (m *mcuProxy) removeConnection(c *mcuProxyConnection) {
m.connectionsMu.Lock()
defer m.connectionsMu.Unlock()
if _, found := m.connectionsMap[c.rawUrl]; found {
delete(m.connectionsMap, c.rawUrl)
m.connections = nil
for _, conn := range m.connectionsMap {
m.connections = append(m.connections, conn)
}
atomic.StoreInt64(&m.nextSort, 0)
}
}
func (m *mcuProxy) SetOnConnected(f func()) {
// Not supported.
}
func (m *mcuProxy) SetOnDisconnected(f func()) {
// Not supported.
}
type mcuProxyStats struct {
Publishers int64 `json:"publishers"`
Clients int64 `json:"clients"`
Details map[string]*mcuProxyConnectionStats `json:"details"`
}
func (m *mcuProxy) GetStats() interface{} {
details := make(map[string]*mcuProxyConnectionStats)
result := &mcuProxyStats{
Details: details,
}
m.connectionsMu.RLock()
defer m.connectionsMu.RUnlock()
for _, conn := range m.connections {
stats := conn.GetStats()
result.Publishers += stats.Publishers
result.Clients += stats.Clients
details[stats.Url] = stats
}
return result
}
type mcuProxyConnectionsList []*mcuProxyConnection
func (l mcuProxyConnectionsList) Len() int {
return len(l)
}
func (l mcuProxyConnectionsList) Less(i, j int) bool {
return l[i].Load() < l[j].Load()
}
func (l mcuProxyConnectionsList) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}
func (l mcuProxyConnectionsList) Sort() {
sort.Sort(l)
}
func ContinentsOverlap(a, b []string) bool {
if len(a) == 0 || len(b) == 0 {
return false
}
for _, checkA := range a {
for _, checkB := range b {
if checkA == checkB {
return true
}
}
}
return false
}
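// sortConnectionsForCountry moves connections in the same country as the
// initiator to the front of the list, followed by connections on the same
// continent and then all remaining connections; the relative order inside
// each group is preserved.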
func sortConnectionsForCountry(connections []*mcuProxyConnection, country string) []*mcuProxyConnection {
// Move connections in the same country to the start of the list.
sorted := make(mcuProxyConnectionsList, 0, len(connections))
unprocessed := make(mcuProxyConnectionsList, 0, len(connections))
for _, conn := range connections {
if country == conn.Country() {
sorted = append(sorted, conn)
} else {
unprocessed = append(unprocessed, conn)
}
}
if continents, found := ContinentMap[country]; found && len(unprocessed) > 1 {
remaining := make(mcuProxyConnectionsList, 0, len(unprocessed))
// Next up are connections on the same continent.
for _, conn := range unprocessed {
connCountry := conn.Country()
if IsValidCountry(connCountry) {
connContinents := ContinentMap[connCountry]
if ContinentsOverlap(continents, connContinents) {
sorted = append(sorted, conn)
} else {
remaining = append(remaining, conn)
}
} else {
remaining = append(remaining, conn)
}
}
unprocessed = remaining
}
// Add all other connections by load.
sorted = append(sorted, unprocessed...)
return sorted
}
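// getSortedConnections returns the connections ordered by load. To limit the
// sorting overhead, the list is only re-sorted every connectionSortRequests
// requests or after connectionSortInterval has passed. If the initiator's
// country is known, connections close to it are moved to the front.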
func (m *mcuProxy) getSortedConnections(initiator McuInitiator) []*mcuProxyConnection {
m.connectionsMu.RLock()
connections := m.connections
m.connectionsMu.RUnlock()
if len(connections) < 2 {
return connections
}
// Connections are re-sorted every <connectionSortRequests> requests or
// every <connectionSortInterval>.
now := time.Now().UnixNano()
if atomic.AddInt64(&m.connRequests, 1)%connectionSortRequests == 0 || atomic.LoadInt64(&m.nextSort) <= now {
atomic.StoreInt64(&m.nextSort, now+int64(connectionSortInterval))
sorted := make(mcuProxyConnectionsList, len(connections))
copy(sorted, connections)
sorted.Sort()
m.connectionsMu.Lock()
m.connections = sorted
m.connectionsMu.Unlock()
connections = sorted
}
if initiator != nil {
if country := initiator.Country(); IsValidCountry(country) {
connections = sortConnectionsForCountry(connections, country)
}
}
return connections
}
func (m *mcuProxy) removePublisher(publisher *mcuProxyPublisher) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.publishers, publisher.id+"|"+publisher.StreamType())
}
func (m *mcuProxy) wakeupWaiters() {
m.mu.RLock()
defer m.mu.RUnlock()
for _, ch := range m.publisherWaiters {
ch <- true
}
}
func (m *mcuProxy) addWaiter(ch chan bool) uint64 {
id := m.publisherWaitersId + 1
m.publisherWaitersId = id
m.publisherWaiters[id] = ch
return id
}
func (m *mcuProxy) removeWaiter(id uint64) {
delete(m.publisherWaiters, id)
}
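// NewPublisher creates a publisher on the first suitable proxy server,
// preferring servers with low load that are close to the initiator and
// skipping servers that are scheduled for shutdown.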
func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, streamType string, initiator McuInitiator) (McuPublisher, error) {
connections := m.getSortedConnections(initiator)
for _, conn := range connections {
if conn.IsShutdownScheduled() {
continue
}
subctx, cancel := context.WithTimeout(ctx, m.proxyTimeout)
defer cancel()
publisher, err := conn.newPublisher(subctx, listener, id, streamType)
if err != nil {
log.Printf("Could not create %s publisher for %s on %s: %s", streamType, id, conn.url, err)
continue
}
m.mu.Lock()
m.publishers[id+"|"+streamType] = conn
m.mu.Unlock()
m.wakeupWaiters()
return publisher, nil
}
return nil, fmt.Errorf("No MCU connection available")
}
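// getPublisherConnection returns the connection a publisher was created on.
// If the publisher is not known yet, the call blocks until it becomes
// available or the context is cancelled.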
func (m *mcuProxy) getPublisherConnection(ctx context.Context, publisher string, streamType string) *mcuProxyConnection {
m.mu.RLock()
conn := m.publishers[publisher+"|"+streamType]
m.mu.RUnlock()
if conn != nil {
return conn
}
log.Printf("No %s publisher %s found yet, deferring", streamType, publisher)
m.mu.Lock()
defer m.mu.Unlock()
conn = m.publishers[publisher+"|"+streamType]
if conn != nil {
return conn
}
ch := make(chan bool, 1)
id := m.addWaiter(ch)
defer m.removeWaiter(id)
for {
m.mu.Unlock()
select {
case <-ch:
m.mu.Lock()
conn = m.publishers[publisher+"|"+streamType]
if conn != nil {
return conn
}
case <-ctx.Done():
m.mu.Lock()
return nil
}
}
}
func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publisher string, streamType string) (McuSubscriber, error) {
conn := m.getPublisherConnection(ctx, publisher, streamType)
if conn == nil {
return nil, fmt.Errorf("No %s publisher %s found", streamType, publisher)
}
return conn.newSubscriber(ctx, listener, publisher, streamType)
}