proxy: Add timeouts to requests to Janus and cancel if session is closed.

This commit is contained in:
Joachim Bauch 2024-10-28 09:24:23 +01:00
commit e1fc062464
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
4 changed files with 62 additions and 22 deletions

View file

@ -158,15 +158,15 @@ func NewClient(ctx context.Context, conn *websocket.Conn, remoteAddress string,
}
client := &Client{
ctx: ctx,
agent: agent,
logRTT: true,
}
client.SetConn(conn, remoteAddress, handler)
client.SetConn(ctx, conn, remoteAddress, handler)
return client, nil
}
func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler ClientHandler) {
func (c *Client) SetConn(ctx context.Context, conn *websocket.Conn, remoteAddress string, handler ClientHandler) {
c.ctx = ctx
c.conn = conn
c.addr = remoteAddress
c.SetHandler(handler)

View file

@ -22,6 +22,7 @@
package main
import (
"context"
"sync/atomic"
"time"
@ -37,11 +38,11 @@ type ProxyClient struct {
session atomic.Pointer[ProxySession]
}
func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*ProxyClient, error) {
func NewProxyClient(ctx context.Context, proxy *ProxyServer, conn *websocket.Conn, addr string) (*ProxyClient, error) {
client := &ProxyClient{
proxy: proxy,
}
client.SetConn(conn, addr, client)
client.SetConn(ctx, conn, addr, client)
return client, nil
}

View file

@ -62,6 +62,9 @@ const (
initialMcuRetry = time.Second
maxMcuRetry = time.Second * 16
// MCU requests will be cancelled if they take too long.
defaultMcuTimeoutSeconds = 10
updateLoadInterval = time.Second
expireSessionsInterval = 10 * time.Second
@ -103,6 +106,7 @@ type ProxyServer struct {
welcomeMessage string
welcomeMsg *signaling.WelcomeServerMessage
config *goconf.ConfigFile
mcuTimeout time.Duration
url string
mcu signaling.Mcu
@ -319,6 +323,12 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*
maxIncoming, maxOutgoing := getTargetBandwidths(config)
mcuTimeoutSeconds, _ := config.GetInt("mcu", "timeout")
if mcuTimeoutSeconds <= 0 {
mcuTimeoutSeconds = defaultMcuTimeoutSeconds
}
mcuTimeout := time.Duration(mcuTimeoutSeconds) * time.Second
result := &ProxyServer{
version: version,
country: country,
@ -328,7 +338,8 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*
Country: country,
Features: defaultProxyFeatures,
},
config: config,
config: config,
mcuTimeout: mcuTimeout,
shutdownChannel: make(chan struct{}),
@ -634,14 +645,14 @@ func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) {
return
}
client, err := NewProxyClient(s, conn, addr)
client, err := NewProxyClient(r.Context(), s, conn, addr)
if err != nil {
log.Printf("Could not create client for %s: %s", addr, err)
return
}
go client.WritePump()
go client.ReadPump()
client.ReadPump()
}
func (s *ProxyServer) clientClosed(client *signaling.Client) {
@ -789,7 +800,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) {
return
}
ctx := context.WithValue(context.Background(), ContextKeySession, session)
ctx := context.WithValue(session.Context(), ContextKeySession, session)
session.MarkUsed()
switch message.Type {
@ -873,8 +884,11 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
return
}
ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout)
defer cancel()
id := uuid.New().String()
publisher, err := s.mcu.NewPublisher(ctx, session, id, cmd.Sid, cmd.StreamType, cmd.Bitrate, cmd.MediaTypes, &emptyInitiator{})
publisher, err := s.mcu.NewPublisher(ctx2, session, id, cmd.Sid, cmd.StreamType, cmd.Bitrate, cmd.MediaTypes, &emptyInitiator{})
if err == context.DeadlineExceeded {
log.Printf("Timeout while creating %s publisher %s for %s", cmd.StreamType, id, session.PublicId())
session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingPublisher))
@ -977,7 +991,10 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
log.Printf("Created remote %s subscriber %s as %s for %s on %s", cmd.StreamType, subscriber.Id(), id, session.PublicId(), cmd.RemoteUrl)
} else {
subscriber, err = s.mcu.NewSubscriber(ctx, session, publisherId, cmd.StreamType, &emptyInitiator{})
ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout)
defer cancel()
subscriber, err = s.mcu.NewSubscriber(ctx2, session, publisherId, cmd.StreamType, &emptyInitiator{})
if err != nil {
handleCreateError(err)
return
@ -1083,7 +1100,10 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
return
}
if err := publisher.PublishRemote(ctx, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil {
ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout)
defer cancel()
if err := publisher.PublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil {
var je *janus.ErrorMsg
if !errors.As(err, &je) || je.Err.Code != signaling.JANUS_VIDEOROOM_ERROR_ID_EXISTS {
log.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err)
@ -1091,13 +1111,19 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
return
}
if err := publisher.UnpublishRemote(ctx, session.PublicId()); err != nil {
ctx2, cancel = context.WithTimeout(ctx, s.mcuTimeout)
defer cancel()
if err := publisher.UnpublishRemote(ctx2, session.PublicId()); err != nil {
log.Printf("Error unpublishing old %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err)
session.sendMessage(message.NewWrappedErrorServerMessage(err))
return
}
if err := publisher.PublishRemote(ctx, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil {
ctx2, cancel = context.WithTimeout(ctx, s.mcuTimeout)
defer cancel()
if err := publisher.PublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil {
log.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err)
session.sendMessage(message.NewWrappedErrorServerMessage(err))
return
@ -1202,7 +1228,10 @@ func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, s
return
}
mcuClient.SendMessage(ctx, nil, mcuData, func(err error, response map[string]interface{}) {
ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout)
defer cancel()
mcuClient.SendMessage(ctx2, nil, mcuData, func(err error, response map[string]interface{}) {
var responseMsg *signaling.ProxyServerMessage
if err != nil {
log.Printf("Error sending %+v to %s client %s: %s", mcuData, mcuClient.StreamType(), payload.ClientId, err)

View file

@ -37,10 +37,12 @@ const (
)
type ProxySession struct {
proxy *ProxyServer
id string
sid uint64
lastUsed atomic.Int64
proxy *ProxyServer
id string
sid uint64
lastUsed atomic.Int64
ctx context.Context
closeFunc context.CancelFunc
clientLock sync.Mutex
client *ProxyClient
@ -56,10 +58,13 @@ type ProxySession struct {
}
func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession {
ctx, closeFunc := context.WithCancel(context.Background())
result := &ProxySession{
proxy: proxy,
id: id,
sid: sid,
proxy: proxy,
id: id,
sid: sid,
ctx: ctx,
closeFunc: closeFunc,
publishers: make(map[string]signaling.McuPublisher),
publisherIds: make(map[signaling.McuPublisher]string),
@ -71,6 +76,10 @@ func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession {
return result
}
func (s *ProxySession) Context() context.Context {
return s.ctx
}
func (s *ProxySession) PublicId() string {
return s.id
}
@ -95,6 +104,7 @@ func (s *ProxySession) MarkUsed() {
}
func (s *ProxySession) Close() {
s.closeFunc()
s.clearPublishers()
s.clearSubscribers()
}