From f2f10a01c9ba1c607ed1467d431e821bcca91738 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 23 Apr 2024 15:57:23 +0200 Subject: [PATCH] Automatically reconnect proxy connections if interrupted. --- proxy/proxy_remote.go | 267 +++++++++++++++++++++++++++++++++--------- proxy/proxy_server.go | 10 +- 2 files changed, 212 insertions(+), 65 deletions(-) diff --git a/proxy/proxy_remote.go b/proxy/proxy_remote.go index 3ca6439..838cecc 100644 --- a/proxy/proxy_remote.go +++ b/proxy/proxy_remote.go @@ -27,7 +27,6 @@ import ( "crypto/tls" "encoding/json" "errors" - "io" "log" "net/http" "net/url" @@ -42,23 +41,44 @@ import ( signaling "github.com/strukturag/nextcloud-spreed-signaling" ) +const ( + initialReconnectInterval = 1 * time.Second + maxReconnectInterval = 32 * time.Second + + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 +) + var ( ErrNotConnected = errors.New("not connected") ) type RemoteConnection struct { - mu sync.Mutex - url *url.URL - conn *websocket.Conn + mu sync.Mutex + url *url.URL + conn *websocket.Conn + closer *signaling.Closer + closed atomic.Bool tokenId string tokenKey *rsa.PrivateKey tlsConfig *tls.Config + connectedSince time.Time + reconnectTimer *time.Timer + reconnectInterval atomic.Int64 + msgId atomic.Int64 helloMsgId string sessionId string + pendingMessages []*signaling.ProxyClientMessage messageCallbacks map[string]chan *signaling.ProxyServerMessage } @@ -69,14 +89,21 @@ func NewRemoteConnection(proxyUrl string, tokenId string, tokenKey *rsa.PrivateK } result := &RemoteConnection{ - url: u, + url: u, + closer: signaling.NewCloser(), tokenId: tokenId, tokenKey: tokenKey, tlsConfig: tlsConfig, + reconnectTimer: time.NewTimer(0), + messageCallbacks: make(map[string]chan *signaling.ProxyServerMessage), } + result.reconnectInterval.Store(int64(initialReconnectInterval)) + + go result.writePump() + return result, nil } @@ -84,17 +111,12 @@ func (c *RemoteConnection) String() string { return c.url.String() } -func (c *RemoteConnection) Connect(ctx context.Context) error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.conn != nil { - return nil - } - +func (c *RemoteConnection) reconnect() { u, err := c.url.Parse("proxy") if err != nil { - return err + log.Printf("Could not resolve url to proxy at %s: %s", c, err) + c.scheduleReconnect() + return } if u.Scheme == "http" { u.Scheme = "ws" @@ -107,15 +129,50 @@ func (c *RemoteConnection) Connect(ctx context.Context) error { TLSClientConfig: c.tlsConfig, } - conn, _, err := dialer.DialContext(ctx, u.String(), nil) + conn, _, err := dialer.DialContext(context.TODO(), u.String(), nil) if err != nil { - return err + log.Printf("Error connecting to proxy at %s: %s", c, err) + c.scheduleReconnect() + return } - c.conn = conn - go c.readPump() + log.Printf("Connected to %s", c) + c.closed.Store(false) - return c.sendHello() + c.mu.Lock() + c.connectedSince = time.Now() + c.conn = conn + c.mu.Unlock() + + c.reconnectInterval.Store(int64(initialReconnectInterval)) + + if err := c.sendHello(); err != nil { + log.Printf("Error sending hello request to proxy at %s: %s", c, err) + c.scheduleReconnect() + return + } + + if !c.sendPing() { + return + } + + go c.readPump(conn) +} + +func (c *RemoteConnection) scheduleReconnect() { + if err := c.sendClose(); err != nil && err != ErrNotConnected { + log.Printf("Could not send close message to %s: %s", c, err) + } + c.close() + + interval := c.reconnectInterval.Load() + c.reconnectTimer.Reset(time.Duration(interval)) + + interval = interval * 2 + if interval > int64(maxReconnectInterval) { + interval = int64(maxReconnectInterval) + } + c.reconnectInterval.Store(interval) } func (c *RemoteConnection) sendHello() error { @@ -138,16 +195,40 @@ func (c *RemoteConnection) sendHello() error { msg.Hello.Token = tokenString } - return c.sendMessageLocked(msg) + return c.SendMessage(msg) +} + +func (c *RemoteConnection) sendClose() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn == nil { + return ErrNotConnected + } + + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint + return c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) +} + +func (c *RemoteConnection) close() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn != nil { + c.conn.Close() + c.conn = nil + } } func (c *RemoteConnection) Close() error { c.mu.Lock() defer c.mu.Unlock() + c.reconnectTimer.Stop() if c.conn == nil { return nil } + c.sendClose() err1 := c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}) err2 := c.conn.Close() c.conn = nil @@ -178,51 +259,74 @@ func (c *RemoteConnection) SendMessage(msg *signaling.ProxyClientMessage) error c.mu.Lock() defer c.mu.Unlock() - return c.sendMessageLocked(msg) + return c.sendMessageLocked(context.Background(), msg) } -func (c *RemoteConnection) sendMessageLocked(msg *signaling.ProxyClientMessage) error { +func (c *RemoteConnection) deferMessage(ctx context.Context, msg *signaling.ProxyClientMessage) { + c.pendingMessages = append(c.pendingMessages, msg) + if ctx.Done() != nil { + go func() { + <-ctx.Done() + + c.mu.Lock() + defer c.mu.Unlock() + for idx, m := range c.pendingMessages { + if m == msg { + c.pendingMessages[idx] = nil + break + } + } + }() + } +} + +func (c *RemoteConnection) sendMessageLocked(ctx context.Context, msg *signaling.ProxyClientMessage) error { if c.conn == nil { - return ErrNotConnected + // Defer until connected. + c.deferMessage(ctx, msg) + return nil } + if c.helloMsgId != "" && c.helloMsgId != msg.Id { + // Hello request is still inflight, defer. + c.deferMessage(ctx, msg) + return nil + } + + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint return c.conn.WriteJSON(msg) } -func (c *RemoteConnection) readPump() { +func (c *RemoteConnection) readPump(conn *websocket.Conn) { + defer func() { + if !c.closed.Load() { + c.scheduleReconnect() + } + }() + defer c.close() + for { - c.mu.Lock() - conn := c.conn - c.mu.Unlock() - if conn == nil { - return - } - - msgType, reader, err := conn.NextReader() + msgType, msg, err := conn.ReadMessage() if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { - log.Printf("error reading: %s", err) + if errors.Is(err, websocket.ErrCloseSent) { + break + } else if _, ok := err.(*websocket.CloseError); !ok || websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived) { + log.Printf("Error reading from %s: %v", c, err) } - c.mu.Lock() - c.conn = nil - c.mu.Unlock() - return - } - - body, err := io.ReadAll(reader) - if err != nil { - log.Printf("error reading message: %s", err) - continue + break } if msgType != websocket.TextMessage { - log.Printf("unexpected message type %q (%s)", msgType, string(body)) + log.Printf("unexpected message type %q (%s)", msgType, string(msg)) continue } - var msg signaling.ProxyServerMessage - if err := json.Unmarshal(body, &msg); err != nil { - log.Printf("could not decode message %s: %s", string(body), err) + var message signaling.ProxyServerMessage + if err := json.Unmarshal(msg, &message); err != nil { + log.Printf("could not decode message %s: %s", string(msg), err) continue } @@ -230,18 +334,53 @@ func (c *RemoteConnection) readPump() { helloMsgId := c.helloMsgId c.mu.Unlock() - if helloMsgId != "" && msg.Id == helloMsgId { - c.processHello(&msg) + if helloMsgId != "" && message.Id == helloMsgId { + c.processHello(&message) } else { - c.processMessage(&msg) + c.processMessage(&message) + } + } +} + +func (c *RemoteConnection) sendPing() bool { + c.mu.Lock() + defer c.mu.Unlock() + if c.conn == nil { + return false + } + + now := time.Now() + msg := strconv.FormatInt(now.UnixNano(), 10) + c.conn.SetWriteDeadline(now.Add(writeWait)) // nolint + if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil { + log.Printf("Could not send ping to proxy at %s: %v", c, err) + go c.scheduleReconnect() + return false + } + + return true +} + +func (c *RemoteConnection) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + }() + + defer c.reconnectTimer.Stop() + for { + select { + case <-c.reconnectTimer.C: + c.reconnect() + case <-ticker.C: + c.sendPing() + case <-c.closer.C: + return } } } func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { - c.mu.Lock() - defer c.mu.Unlock() - c.helloMsgId = "" switch msg.Type { case "error": @@ -250,13 +389,13 @@ func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { c.sessionId = "" if err := c.sendHello(); err != nil { log.Printf("Could not send hello request to %s: %s", c, err) - // TODO: c.scheduleReconnect() + c.scheduleReconnect() } return } log.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error) - // TODO: c.scheduleReconnect() + c.scheduleReconnect() case "hello": resumed := c.sessionId == msg.Hello.SessionId c.sessionId = msg.Hello.SessionId @@ -274,9 +413,21 @@ func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { } else { log.Printf("Received session %s from %s", c.sessionId, c) } + + pending := c.pendingMessages + c.pendingMessages = nil + for _, m := range pending { + if m == nil { + continue + } + + if err := c.sendMessageLocked(context.Background(), m); err != nil { + log.Printf("Could not send pending message %+v to %s: %s", m, c, err) + } + } default: log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c) - // TODO: c.scheduleReconnect() + c.scheduleReconnect() } } @@ -315,7 +466,7 @@ func (c *RemoteConnection) RequestMessage(ctx context.Context, msg *signaling.Pr c.mu.Lock() defer c.mu.Unlock() - if err := c.sendMessageLocked(msg); err != nil { + if err := c.sendMessageLocked(ctx, msg); err != nil { return nil, err } ch := make(chan *signaling.ProxyServerMessage, 1) diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 705ddb2..2552774 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -725,7 +725,7 @@ func (p *proxyRemotePublisher) PublisherId() string { } func (p *proxyRemotePublisher) StartPublishing(ctx context.Context, publisher signaling.McuRemotePublisherProperties) error { - conn, err := p.proxy.getRemoteConnection(ctx, p.remoteUrl) + conn, err := p.proxy.getRemoteConnection(p.remoteUrl) if err != nil { return err } @@ -747,7 +747,7 @@ func (p *proxyRemotePublisher) StartPublishing(ctx context.Context, publisher si } func (p *proxyRemotePublisher) GetStreams(ctx context.Context) ([]signaling.PublisherStream, error) { - conn, err := p.proxy.getRemoteConnection(ctx, p.remoteUrl) + conn, err := p.proxy.getRemoteConnection(p.remoteUrl) if err != nil { return nil, err } @@ -1352,7 +1352,7 @@ func (s *ProxyServer) metricsHandler(w http.ResponseWriter, r *http.Request) { promhttp.Handler().ServeHTTP(w, r) } -func (s *ProxyServer) getRemoteConnection(ctx context.Context, url string) (*RemoteConnection, error) { +func (s *ProxyServer) getRemoteConnection(url string) (*RemoteConnection, error) { s.remoteConnectionsLock.Lock() defer s.remoteConnectionsLock.Unlock() @@ -1366,10 +1366,6 @@ func (s *ProxyServer) getRemoteConnection(ctx context.Context, url string) (*Rem return nil, err } - if err := conn.Connect(ctx); err != nil { - return nil, err - } - s.remoteConnections[url] = conn return conn, nil }