diff --git a/clientsession_test.go b/clientsession_test.go index 43de0e4..6d3b9a4 100644 --- a/clientsession_test.go +++ b/clientsession_test.go @@ -131,10 +131,13 @@ func TestBandwidth_Client(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() @@ -148,9 +151,6 @@ func TestBandwidth_Client(t *testing.T) { t.Fatal(err) } - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - hello, err := client.RunUntilHello(ctx) if err != nil { t.Fatal(err) @@ -217,10 +217,13 @@ func TestBandwidth_Backend(t *testing.T) { backend.maxScreenBitrate = 1000 backend.maxStreamBitrate = 2000 + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() @@ -232,9 +235,6 @@ func TestBandwidth_Backend(t *testing.T) { StreamTypeScreen, } - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - for _, streamType := range streamTypes { t.Run(string(streamType), func(t *testing.T) { client := NewTestClient(t, server, hub) diff --git a/hub_test.go b/hub_test.go index d6ab70b..2ebc64f 100644 --- a/hub_test.go +++ b/hub_test.go @@ -4029,19 +4029,19 @@ func TestClientSendOfferPermissions(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() hub.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() @@ -4170,19 +4170,19 @@ func TestClientSendOfferPermissionsAudioOnly(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() hub.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() @@ -4263,19 +4263,19 @@ func TestClientSendOfferPermissionsAudioVideo(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() hub.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() @@ -4392,19 +4392,19 @@ func TestClientSendOfferPermissionsAudioVideoMedia(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() hub.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() @@ -4539,10 +4539,13 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t) } + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() @@ -4550,9 +4553,6 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { hub1.SetMcu(mcu) hub2.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() @@ -4965,10 +4965,13 @@ func TestClientSendOffer(t *testing.T) { hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t) } + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() @@ -4976,9 +4979,6 @@ func TestClientSendOffer(t *testing.T) { hub1.SetMcu(mcu) hub2.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server1, hub1) defer client1.CloseWithBye() @@ -5073,19 +5073,19 @@ func TestClientUnshareScreen(t *testing.T) { CatchLogForTest(t) hub, _, _, server := CreateHubForTest(t) + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + mcu, err := NewTestMCU() if err != nil { t.Fatal(err) - } else if err := mcu.Start(); err != nil { + } else if err := mcu.Start(ctx); err != nil { t.Fatal(err) } defer mcu.Stop() hub.SetMcu(mcu) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - client1 := NewTestClient(t, server, hub) defer client1.CloseWithBye() diff --git a/janus_client.go b/janus_client.go index 0865f45..b7b33a5 100644 --- a/janus_client.go +++ b/janus_client.go @@ -258,8 +258,8 @@ type JanusGateway struct { // return gateway, nil // } -func NewJanusGateway(wsURL string, listener GatewayListener) (*JanusGateway, error) { - conn, _, err := janusDialer.Dial(wsURL, nil) +func NewJanusGateway(ctx context.Context, wsURL string, listener GatewayListener) (*JanusGateway, error) { + conn, _, err := janusDialer.DialContext(ctx, wsURL, nil) if err != nil { return nil, err } diff --git a/mcu_common.go b/mcu_common.go index 8fbca2b..8ac820c 100644 --- a/mcu_common.go +++ b/mcu_common.go @@ -66,7 +66,7 @@ type McuInitiator interface { } type Mcu interface { - Start() error + Start(ctx context.Context) error Stop() Reload(config *goconf.ConfigFile) diff --git a/mcu_janus.go b/mcu_janus.go index 948f7da..0f70328 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -169,7 +169,7 @@ type mcuJanus struct { func emptyOnConnected() {} func emptyOnDisconnected() {} -func NewMcuJanus(url string, config *goconf.ConfigFile) (Mcu, error) { +func NewMcuJanus(ctx context.Context, url string, config *goconf.ConfigFile) (Mcu, error) { maxStreamBitrate, _ := config.GetInt("mcu", "maxstreambitrate") if maxStreamBitrate <= 0 { maxStreamBitrate = defaultMaxStreamBitrate @@ -200,9 +200,11 @@ func NewMcuJanus(url string, config *goconf.ConfigFile) (Mcu, error) { mcu.onConnected.Store(emptyOnConnected) mcu.onDisconnected.Store(emptyOnDisconnected) - mcu.reconnectTimer = time.AfterFunc(mcu.reconnectInterval, mcu.doReconnect) + mcu.reconnectTimer = time.AfterFunc(mcu.reconnectInterval, func() { + mcu.doReconnect(context.Background()) + }) mcu.reconnectTimer.Stop() - if err := mcu.reconnect(); err != nil { + if err := mcu.reconnect(ctx); err != nil { return nil, err } return mcu, nil @@ -230,9 +232,9 @@ func (m *mcuJanus) disconnect() { } } -func (m *mcuJanus) reconnect() error { +func (m *mcuJanus) reconnect(ctx context.Context) error { m.disconnect() - gw, err := NewJanusGateway(m.url, m) + gw, err := NewJanusGateway(ctx, m.url, m) if err != nil { return err } @@ -242,12 +244,12 @@ func (m *mcuJanus) reconnect() error { return nil } -func (m *mcuJanus) doReconnect() { - if err := m.reconnect(); err != nil { +func (m *mcuJanus) doReconnect(ctx context.Context) { + if err := m.reconnect(ctx); err != nil { m.scheduleReconnect(err) return } - if err := m.Start(); err != nil { + if err := m.Start(ctx); err != nil { m.scheduleReconnect(err) return } @@ -296,8 +298,7 @@ func (m *mcuJanus) hasRemotePublisher() bool { return m.version >= 1100 } -func (m *mcuJanus) Start() error { - ctx := context.TODO() +func (m *mcuJanus) Start(ctx context.Context) error { info, err := m.gw.Info(ctx) if err != nil { return err @@ -364,7 +365,7 @@ loop: for { select { case <-ticker.C: - m.sendKeepalive() + m.sendKeepalive(context.Background()) case <-m.closeChan: break loop } @@ -430,8 +431,7 @@ func (m *mcuJanus) GetStats() interface{} { return result } -func (m *mcuJanus) sendKeepalive() { - ctx := context.TODO() +func (m *mcuJanus) sendKeepalive(ctx context.Context) { if _, err := m.session.KeepAlive(ctx); err != nil { log.Println("Could not send keepalive request", err) if e, ok := err.(*janus.ErrorMsg); ok { diff --git a/mcu_proxy.go b/mcu_proxy.go index ea0f3ca..5b34426 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -1395,7 +1395,7 @@ func (m *mcuProxy) loadContinentsMap(config *goconf.ConfigFile) error { return nil } -func (m *mcuProxy) Start() error { +func (m *mcuProxy) Start(ctx context.Context) error { log.Printf("Maximum bandwidth %d bits/sec per publishing stream", m.maxStreamBitrate) log.Printf("Maximum bandwidth %d bits/sec per screensharing stream", m.maxScreenBitrate) diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index b6fe38d..39f12a9 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go @@ -713,13 +713,14 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions) *mcuP mcu.Stop() }) - if err := mcu.Start(); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + if err := mcu.Start(ctx); err != nil { t.Fatal(err) } proxy := mcu.(*mcuProxy) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() if err := proxy.WaitForConnections(ctx); err != nil { t.Fatal(err) diff --git a/mcu_test.go b/mcu_test.go index a1ee9cc..ae1de23 100644 --- a/mcu_test.go +++ b/mcu_test.go @@ -50,7 +50,7 @@ func NewTestMCU() (*TestMCU, error) { }, nil } -func (m *TestMCU) Start() error { +func (m *TestMCU) Start(ctx context.Context) error { return nil } diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 972a49f..1d0d4fe 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -386,7 +386,7 @@ func (s *ProxyServer) Start(config *goconf.ConfigFile) error { for { switch mcuType { case signaling.McuTypeJanus: - mcu, err = signaling.NewMcuJanus(s.url, config) + mcu, err = signaling.NewMcuJanus(ctx, s.url, config) if err == nil { signaling.RegisterJanusMcuStats() } @@ -396,7 +396,7 @@ func (s *ProxyServer) Start(config *goconf.ConfigFile) error { if err == nil { mcu.SetOnConnected(s.onMcuConnected) mcu.SetOnDisconnected(s.onMcuDisconnected) - err = mcu.Start() + err = mcu.Start(ctx) if err != nil { log.Printf("Could not create %s MCU at %s: %s", mcuType, s.url, err) } diff --git a/server/main.go b/server/main.go index 9ee6afd..a31a0f5 100644 --- a/server/main.go +++ b/server/main.go @@ -22,6 +22,7 @@ package main import ( + "context" "crypto/tls" "errors" "flag" @@ -240,9 +241,11 @@ func main() { mcuRetryTimer := time.NewTimer(mcuRetry) mcuTypeLoop: for { + // Context should be cancelled on signals but need a way to differentiate later. + ctx := context.TODO() switch mcuType { case signaling.McuTypeJanus: - mcu, err = signaling.NewMcuJanus(mcuUrl, config) + mcu, err = signaling.NewMcuJanus(ctx, mcuUrl, config) signaling.UnregisterProxyMcuStats() signaling.RegisterJanusMcuStats() case signaling.McuTypeProxy: @@ -253,7 +256,7 @@ func main() { log.Fatal("Unsupported MCU type: ", mcuType) } if err == nil { - err = mcu.Start() + err = mcu.Start(ctx) if err != nil { log.Printf("Could not create %s MCU: %s", mcuType, err) }