Pass contexts when creating / starting MCUs.

This commit is contained in:
Joachim Bauch 2024-05-16 16:53:41 +02:00
parent ae7c498cb9
commit 35cf2cafc2
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
10 changed files with 65 additions and 61 deletions

View file

@ -131,10 +131,13 @@ func TestBandwidth_Client(t *testing.T) {
CatchLogForTest(t)
hub, _, _, server := CreateHubForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(); err != nil {
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
defer mcu.Stop()
@ -148,9 +151,6 @@ func TestBandwidth_Client(t *testing.T) {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello, err := client.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
@ -217,10 +217,13 @@ func TestBandwidth_Backend(t *testing.T) {
backend.maxScreenBitrate = 1000
backend.maxStreamBitrate = 2000
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(); err != nil {
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
defer mcu.Stop()
@ -232,9 +235,6 @@ func TestBandwidth_Backend(t *testing.T) {
StreamTypeScreen,
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
for _, streamType := range streamTypes {
t.Run(string(streamType), func(t *testing.T) {
client := NewTestClient(t, server, hub)

View file

@ -4011,19 +4011,19 @@ func TestClientSendOfferPermissions(t *testing.T) {
CatchLogForTest(t)
hub, _, _, server := CreateHubForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(); err != nil {
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
defer mcu.Stop()
hub.SetMcu(mcu)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client1 := NewTestClient(t, server, hub)
defer client1.CloseWithBye()
@ -4152,19 +4152,19 @@ func TestClientSendOfferPermissionsAudioOnly(t *testing.T) {
CatchLogForTest(t)
hub, _, _, server := CreateHubForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(); err != nil {
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
defer mcu.Stop()
hub.SetMcu(mcu)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client1 := NewTestClient(t, server, hub)
defer client1.CloseWithBye()
@ -4245,19 +4245,19 @@ func TestClientSendOfferPermissionsAudioVideo(t *testing.T) {
CatchLogForTest(t)
hub, _, _, server := CreateHubForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(); err != nil {
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
defer mcu.Stop()
hub.SetMcu(mcu)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client1 := NewTestClient(t, server, hub)
defer client1.CloseWithBye()
@ -4374,19 +4374,19 @@ func TestClientSendOfferPermissionsAudioVideoMedia(t *testing.T) {
CatchLogForTest(t)
hub, _, _, server := CreateHubForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(); err != nil {
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
defer mcu.Stop()
hub.SetMcu(mcu)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client1 := NewTestClient(t, server, hub)
defer client1.CloseWithBye()
@ -4521,10 +4521,13 @@ func TestClientRequestOfferNotInRoom(t *testing.T) {
hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t)
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(); err != nil {
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
defer mcu.Stop()
@ -4532,9 +4535,6 @@ func TestClientRequestOfferNotInRoom(t *testing.T) {
hub1.SetMcu(mcu)
hub2.SetMcu(mcu)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client1 := NewTestClient(t, server1, hub1)
defer client1.CloseWithBye()
@ -4947,10 +4947,13 @@ func TestClientSendOffer(t *testing.T) {
hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t)
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(); err != nil {
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
defer mcu.Stop()
@ -4958,9 +4961,6 @@ func TestClientSendOffer(t *testing.T) {
hub1.SetMcu(mcu)
hub2.SetMcu(mcu)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client1 := NewTestClient(t, server1, hub1)
defer client1.CloseWithBye()
@ -5055,19 +5055,19 @@ func TestClientUnshareScreen(t *testing.T) {
CatchLogForTest(t)
hub, _, _, server := CreateHubForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
mcu, err := NewTestMCU()
if err != nil {
t.Fatal(err)
} else if err := mcu.Start(); err != nil {
} else if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
defer mcu.Stop()
hub.SetMcu(mcu)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
client1 := NewTestClient(t, server, hub)
defer client1.CloseWithBye()

View file

@ -258,8 +258,8 @@ type JanusGateway struct {
// return gateway, nil
// }
func NewJanusGateway(wsURL string, listener GatewayListener) (*JanusGateway, error) {
conn, _, err := janusDialer.Dial(wsURL, nil)
func NewJanusGateway(ctx context.Context, wsURL string, listener GatewayListener) (*JanusGateway, error) {
conn, _, err := janusDialer.DialContext(ctx, wsURL, nil)
if err != nil {
return nil, err
}

View file

@ -66,7 +66,7 @@ type McuInitiator interface {
}
type Mcu interface {
Start() error
Start(ctx context.Context) error
Stop()
Reload(config *goconf.ConfigFile)

View file

@ -169,7 +169,7 @@ type mcuJanus struct {
func emptyOnConnected() {}
func emptyOnDisconnected() {}
func NewMcuJanus(url string, config *goconf.ConfigFile) (Mcu, error) {
func NewMcuJanus(ctx context.Context, url string, config *goconf.ConfigFile) (Mcu, error) {
maxStreamBitrate, _ := config.GetInt("mcu", "maxstreambitrate")
if maxStreamBitrate <= 0 {
maxStreamBitrate = defaultMaxStreamBitrate
@ -200,9 +200,11 @@ func NewMcuJanus(url string, config *goconf.ConfigFile) (Mcu, error) {
mcu.onConnected.Store(emptyOnConnected)
mcu.onDisconnected.Store(emptyOnDisconnected)
mcu.reconnectTimer = time.AfterFunc(mcu.reconnectInterval, mcu.doReconnect)
mcu.reconnectTimer = time.AfterFunc(mcu.reconnectInterval, func() {
mcu.doReconnect(context.Background())
})
mcu.reconnectTimer.Stop()
if err := mcu.reconnect(); err != nil {
if err := mcu.reconnect(ctx); err != nil {
return nil, err
}
return mcu, nil
@ -230,9 +232,9 @@ func (m *mcuJanus) disconnect() {
}
}
func (m *mcuJanus) reconnect() error {
func (m *mcuJanus) reconnect(ctx context.Context) error {
m.disconnect()
gw, err := NewJanusGateway(m.url, m)
gw, err := NewJanusGateway(ctx, m.url, m)
if err != nil {
return err
}
@ -242,12 +244,12 @@ func (m *mcuJanus) reconnect() error {
return nil
}
func (m *mcuJanus) doReconnect() {
if err := m.reconnect(); err != nil {
func (m *mcuJanus) doReconnect(ctx context.Context) {
if err := m.reconnect(ctx); err != nil {
m.scheduleReconnect(err)
return
}
if err := m.Start(); err != nil {
if err := m.Start(ctx); err != nil {
m.scheduleReconnect(err)
return
}
@ -296,8 +298,7 @@ func (m *mcuJanus) hasRemotePublisher() bool {
return m.version >= 1100
}
func (m *mcuJanus) Start() error {
ctx := context.TODO()
func (m *mcuJanus) Start(ctx context.Context) error {
info, err := m.gw.Info(ctx)
if err != nil {
return err
@ -364,7 +365,7 @@ loop:
for {
select {
case <-ticker.C:
m.sendKeepalive()
m.sendKeepalive(context.Background())
case <-m.closeChan:
break loop
}
@ -430,8 +431,7 @@ func (m *mcuJanus) GetStats() interface{} {
return result
}
func (m *mcuJanus) sendKeepalive() {
ctx := context.TODO()
func (m *mcuJanus) sendKeepalive(ctx context.Context) {
if _, err := m.session.KeepAlive(ctx); err != nil {
log.Println("Could not send keepalive request", err)
if e, ok := err.(*janus.ErrorMsg); ok {

View file

@ -1395,7 +1395,7 @@ func (m *mcuProxy) loadContinentsMap(config *goconf.ConfigFile) error {
return nil
}
func (m *mcuProxy) Start() error {
func (m *mcuProxy) Start(ctx context.Context) error {
log.Printf("Maximum bandwidth %d bits/sec per publishing stream", m.maxStreamBitrate)
log.Printf("Maximum bandwidth %d bits/sec per screensharing stream", m.maxScreenBitrate)

View file

@ -713,13 +713,14 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP
mcu.Stop()
})
if err := mcu.Start(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if err := mcu.Start(ctx); err != nil {
t.Fatal(err)
}
proxy := mcu.(*mcuProxy)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if err := proxy.WaitForConnections(ctx); err != nil {
t.Fatal(err)

View file

@ -50,7 +50,7 @@ func NewTestMCU() (*TestMCU, error) {
}, nil
}
func (m *TestMCU) Start() error {
func (m *TestMCU) Start(ctx context.Context) error {
return nil
}

View file

@ -386,7 +386,7 @@ func (s *ProxyServer) Start(config *goconf.ConfigFile) error {
for {
switch mcuType {
case signaling.McuTypeJanus:
mcu, err = signaling.NewMcuJanus(s.url, config)
mcu, err = signaling.NewMcuJanus(ctx, s.url, config)
if err == nil {
signaling.RegisterJanusMcuStats()
}
@ -396,7 +396,7 @@ func (s *ProxyServer) Start(config *goconf.ConfigFile) error {
if err == nil {
mcu.SetOnConnected(s.onMcuConnected)
mcu.SetOnDisconnected(s.onMcuDisconnected)
err = mcu.Start()
err = mcu.Start(ctx)
if err != nil {
log.Printf("Could not create %s MCU at %s: %s", mcuType, s.url, err)
}

View file

@ -22,6 +22,7 @@
package main
import (
"context"
"crypto/tls"
"errors"
"flag"
@ -240,9 +241,11 @@ func main() {
mcuRetryTimer := time.NewTimer(mcuRetry)
mcuTypeLoop:
for {
// Context should be cancelled on signals but need a way to differentiate later.
ctx := context.TODO()
switch mcuType {
case signaling.McuTypeJanus:
mcu, err = signaling.NewMcuJanus(mcuUrl, config)
mcu, err = signaling.NewMcuJanus(ctx, mcuUrl, config)
signaling.UnregisterProxyMcuStats()
signaling.RegisterJanusMcuStats()
case signaling.McuTypeProxy:
@ -253,7 +256,7 @@ func main() {
log.Fatal("Unsupported MCU type: ", mcuType)
}
if err == nil {
err = mcu.Start()
err = mcu.Start(ctx)
if err != nil {
log.Printf("Could not create %s MCU: %s", mcuType, err)
}