From 6ca41dee618ddade93cf07d59601df63ef054032 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 19 Nov 2025 16:03:05 +0100 Subject: [PATCH] Don't use global logger. --- async_events.go | 11 +- async_events_nats.go | 43 +++--- async_events_test.go | 8 +- backend_client.go | 31 ++-- backend_client_test.go | 25 ++-- backend_configuration.go | 6 +- backend_configuration_test.go | 73 +++++---- backend_server.go | 109 +++++++------- backend_server_test.go | 113 +++++++------- backend_storage_etcd.go | 27 ++-- backend_storage_etcd_test.go | 4 +- backend_storage_static.go | 51 +++---- capabilities.go | 46 +++--- capabilities_test.go | 45 +++--- certificate_reloader.go | 25 ++-- client.go | 37 ++--- clientsession.go | 74 ++++----- clientsession_test.go | 13 +- deferred_executor.go | 11 +- deferred_executor_test.go | 19 ++- dns_monitor.go | 7 +- dns_monitor_test.go | 3 +- etcd_client.go | 27 ++-- etcd_client_test.go | 32 ++-- federation.go | 42 +++--- federation_test.go | 20 +-- file_watcher.go | 11 +- file_watcher_test.go | 27 ++-- geoip.go | 30 ++-- geoip_test.go | 16 +- grpc_client.go | 83 +++++----- grpc_client_test.go | 29 ++-- grpc_common.go | 11 +- grpc_remote_client.go | 9 +- grpc_server.go | 24 +-- grpc_server_test.go | 11 +- hub.go | 250 ++++++++++++++++--------------- hub_test.go | 134 ++++------------- mcu_common.go | 7 +- mcu_janus.go | 102 +++++++------ mcu_janus_client.go | 9 +- mcu_janus_events_handler.go | 42 +++--- mcu_janus_events_handler_test.go | 4 +- mcu_janus_publisher.go | 43 +++--- mcu_janus_publisher_test.go | 1 - mcu_janus_remote_publisher.go | 33 ++-- mcu_janus_remote_subscriber.go | 23 ++- mcu_janus_subscriber.go | 43 +++--- mcu_janus_test.go | 20 +-- mcu_proxy.go | 216 +++++++++++++------------- mcu_proxy_test.go | 86 ++++------- mcu_test.go | 15 +- natsclient.go | 29 ++-- natsclient_loopback.go | 9 +- natsclient_loopback_test.go | 3 +- natsclient_test.go | 9 +- proxy/main.go | 40 ++--- proxy/proxy_remote.go | 53 +++---- proxy/proxy_server.go | 153 +++++++++---------- proxy/proxy_server_test.go | 20 +-- proxy/proxy_session.go | 15 +- proxy/proxy_tokens_etcd.go | 17 ++- proxy/proxy_tokens_etcd_test.go | 4 +- proxy/proxy_tokens_static.go | 20 +-- proxy_config_etcd.go | 31 ++-- proxy_config_etcd_test.go | 6 +- proxy_config_static.go | 15 +- proxy_config_static_test.go | 5 +- remotesession.go | 7 +- room.go | 67 +++++---- room_ping.go | 27 ++-- room_ping_test.go | 40 ++--- room_test.go | 35 +++-- roomsessions_builtin.go | 6 +- server/main.go | 126 ++++++++-------- test_helpers.go | 23 --- throttle.go | 7 +- throttle_test.go | 21 ++- transient_data_test.go | 1 - virtualsession.go | 18 ++- virtualsession_test.go | 4 - 81 files changed, 1494 insertions(+), 1498 deletions(-) diff --git a/async_events.go b/async_events.go index d8bb0c9..6598cd3 100644 --- a/async_events.go +++ b/async_events.go @@ -21,7 +21,10 @@ */ package signaling -import "sync" +import ( + "context" + "sync" +) type AsyncBackendRoomEventListener interface { ProcessBackendRoomRequest(message *AsyncMessage) @@ -60,13 +63,13 @@ type AsyncEvents interface { PublishSessionMessage(sessionId PublicSessionId, backend *Backend, message *AsyncMessage) error } -func NewAsyncEvents(url string) (AsyncEvents, error) { - client, err := NewNatsClient(url) +func NewAsyncEvents(ctx context.Context, url string) (AsyncEvents, error) { + client, err := NewNatsClient(ctx, url) if err != nil { return nil, err } - return NewAsyncEventsNats(client) + return NewAsyncEventsNats(LoggerFromContext(ctx), client) } type asyncBackendRoomSubscriber struct { diff --git a/async_events_nats.go b/async_events_nats.go index b5c3ae5..ab9fe13 100644 --- a/async_events_nats.go +++ b/async_events_nats.go @@ -23,7 +23,6 @@ package signaling import ( "fmt" - "log" "sync" "time" @@ -61,6 +60,7 @@ func GetSubjectForSessionId(sessionId PublicSessionId, backend *Backend) string type asyncSubscriberNats struct { key string client NatsClient + logger Logger receiver chan *nats.Msg closeChan chan struct{} @@ -69,7 +69,7 @@ type asyncSubscriberNats struct { processMessage func(*nats.Msg) } -func newAsyncSubscriberNats(key string, client NatsClient) (*asyncSubscriberNats, error) { +func newAsyncSubscriberNats(logger Logger, key string, client NatsClient) (*asyncSubscriberNats, error) { receiver := make(chan *nats.Msg, 64) sub, err := client.Subscribe(key, receiver) if err != nil { @@ -79,6 +79,7 @@ func newAsyncSubscriberNats(key string, client NatsClient) (*asyncSubscriberNats result := &asyncSubscriberNats{ key: key, client: client, + logger: logger, receiver: receiver, closeChan: make(chan struct{}), @@ -90,7 +91,7 @@ func newAsyncSubscriberNats(key string, client NatsClient) (*asyncSubscriberNats func (s *asyncSubscriberNats) run() { defer func() { if err := s.subscription.Unsubscribe(); err != nil { - log.Printf("Error unsubscribing %s: %s", s.key, err) + s.logger.Printf("Error unsubscribing %s: %s", s.key, err) } }() @@ -116,8 +117,8 @@ type asyncBackendRoomSubscriberNats struct { asyncBackendRoomSubscriber } -func newAsyncBackendRoomSubscriberNats(key string, client NatsClient) (*asyncBackendRoomSubscriberNats, error) { - sub, err := newAsyncSubscriberNats(key, client) +func newAsyncBackendRoomSubscriberNats(logger Logger, key string, client NatsClient) (*asyncBackendRoomSubscriberNats, error) { + sub, err := newAsyncSubscriberNats(logger, key, client) if err != nil { return nil, err } @@ -133,7 +134,7 @@ func newAsyncBackendRoomSubscriberNats(key string, client NatsClient) (*asyncBac func (s *asyncBackendRoomSubscriberNats) doProcessMessage(msg *nats.Msg) { var message AsyncMessage if err := s.client.Decode(msg, &message); err != nil { - log.Printf("Could not decode NATS message %+v, %s", msg, err) + s.logger.Printf("Could not decode NATS message %+v, %s", msg, err) return } @@ -145,8 +146,8 @@ type asyncRoomSubscriberNats struct { *asyncSubscriberNats } -func newAsyncRoomSubscriberNats(key string, client NatsClient) (*asyncRoomSubscriberNats, error) { - sub, err := newAsyncSubscriberNats(key, client) +func newAsyncRoomSubscriberNats(logger Logger, key string, client NatsClient) (*asyncRoomSubscriberNats, error) { + sub, err := newAsyncSubscriberNats(logger, key, client) if err != nil { return nil, err } @@ -162,7 +163,7 @@ func newAsyncRoomSubscriberNats(key string, client NatsClient) (*asyncRoomSubscr func (s *asyncRoomSubscriberNats) doProcessMessage(msg *nats.Msg) { var message AsyncMessage if err := s.client.Decode(msg, &message); err != nil { - log.Printf("Could not decode nats message %+v, %s", msg, err) + s.logger.Printf("Could not decode NATS message %+v, %s", msg, err) return } @@ -174,8 +175,8 @@ type asyncUserSubscriberNats struct { asyncUserSubscriber } -func newAsyncUserSubscriberNats(key string, client NatsClient) (*asyncUserSubscriberNats, error) { - sub, err := newAsyncSubscriberNats(key, client) +func newAsyncUserSubscriberNats(logger Logger, key string, client NatsClient) (*asyncUserSubscriberNats, error) { + sub, err := newAsyncSubscriberNats(logger, key, client) if err != nil { return nil, err } @@ -191,7 +192,7 @@ func newAsyncUserSubscriberNats(key string, client NatsClient) (*asyncUserSubscr func (s *asyncUserSubscriberNats) doProcessMessage(msg *nats.Msg) { var message AsyncMessage if err := s.client.Decode(msg, &message); err != nil { - log.Printf("Could not decode nats message %+v, %s", msg, err) + s.logger.Printf("Could not decode NATS message %+v, %s", msg, err) return } @@ -203,8 +204,8 @@ type asyncSessionSubscriberNats struct { asyncSessionSubscriber } -func newAsyncSessionSubscriberNats(key string, client NatsClient) (*asyncSessionSubscriberNats, error) { - sub, err := newAsyncSubscriberNats(key, client) +func newAsyncSessionSubscriberNats(logger Logger, key string, client NatsClient) (*asyncSessionSubscriberNats, error) { + sub, err := newAsyncSubscriberNats(logger, key, client) if err != nil { return nil, err } @@ -220,7 +221,7 @@ func newAsyncSessionSubscriberNats(key string, client NatsClient) (*asyncSession func (s *asyncSessionSubscriberNats) doProcessMessage(msg *nats.Msg) { var message AsyncMessage if err := s.client.Decode(msg, &message); err != nil { - log.Printf("Could not decode nats message %+v, %s", msg, err) + s.logger.Printf("Could not decode NATS message %+v, %s", msg, err) return } @@ -230,6 +231,7 @@ func (s *asyncSessionSubscriberNats) doProcessMessage(msg *nats.Msg) { type asyncEventsNats struct { mu sync.Mutex client NatsClient + logger Logger // +checklocksignore // +checklocks:mu backendRoomSubscriptions map[string]*asyncBackendRoomSubscriberNats @@ -241,9 +243,10 @@ type asyncEventsNats struct { sessionSubscriptions map[string]*asyncSessionSubscriberNats } -func NewAsyncEventsNats(client NatsClient) (AsyncEvents, error) { +func NewAsyncEventsNats(logger Logger, client NatsClient) (AsyncEvents, error) { events := &asyncEventsNats{ client: client, + logger: logger, backendRoomSubscriptions: make(map[string]*asyncBackendRoomSubscriberNats), roomSubscriptions: make(map[string]*asyncRoomSubscriberNats), @@ -328,7 +331,7 @@ func (e *asyncEventsNats) RegisterBackendRoomListener(roomId string, backend *Ba sub, found := e.backendRoomSubscriptions[key] if !found { var err error - if sub, err = newAsyncBackendRoomSubscriberNats(key, e.client); err != nil { + if sub, err = newAsyncBackendRoomSubscriberNats(e.logger, key, e.client); err != nil { return err } @@ -362,7 +365,7 @@ func (e *asyncEventsNats) RegisterRoomListener(roomId string, backend *Backend, sub, found := e.roomSubscriptions[key] if !found { var err error - if sub, err = newAsyncRoomSubscriberNats(key, e.client); err != nil { + if sub, err = newAsyncRoomSubscriberNats(e.logger, key, e.client); err != nil { return err } @@ -396,7 +399,7 @@ func (e *asyncEventsNats) RegisterUserListener(roomId string, backend *Backend, sub, found := e.userSubscriptions[key] if !found { var err error - if sub, err = newAsyncUserSubscriberNats(key, e.client); err != nil { + if sub, err = newAsyncUserSubscriberNats(e.logger, key, e.client); err != nil { return err } @@ -430,7 +433,7 @@ func (e *asyncEventsNats) RegisterSessionListener(sessionId PublicSessionId, bac sub, found := e.sessionSubscriptions[key] if !found { var err error - if sub, err = newAsyncSessionSubscriberNats(key, e.client); err != nil { + if sub, err = newAsyncSessionSubscriberNats(e.logger, key, e.client); err != nil { return err } diff --git a/async_events_test.go b/async_events_test.go index 3ded5a0..02d6145 100644 --- a/async_events_test.go +++ b/async_events_test.go @@ -50,8 +50,10 @@ func getAsyncEventsForTest(t *testing.T) AsyncEvents { } func getRealAsyncEventsForTest(t *testing.T) AsyncEvents { + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) server, _ := startLocalNatsServer(t) - events, err := NewAsyncEvents(server.ClientURL()) + events, err := NewAsyncEvents(ctx, server.ClientURL()) if err != nil { require.NoError(t, err) } @@ -59,7 +61,9 @@ func getRealAsyncEventsForTest(t *testing.T) AsyncEvents { } func getLoopbackAsyncEventsForTest(t *testing.T) AsyncEvents { - events, err := NewAsyncEvents(NatsLoopbackUrl) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + events, err := NewAsyncEvents(ctx, NatsLoopbackUrl) if err != nil { require.NoError(t, err) } diff --git a/backend_client.go b/backend_client.go index 769d02c..54ee18f 100644 --- a/backend_client.go +++ b/backend_client.go @@ -26,7 +26,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "net/http" "net/url" "strings" @@ -57,15 +56,16 @@ type BackendClient struct { buffers BufferPool } -func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string, etcdClient *EtcdClient) (*BackendClient, error) { - backends, err := NewBackendConfiguration(config, etcdClient) +func NewBackendClient(ctx context.Context, config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string, etcdClient *EtcdClient) (*BackendClient, error) { + logger := LoggerFromContext(ctx) + backends, err := NewBackendConfiguration(logger, config, etcdClient) if err != nil { return nil, err } skipverify, _ := config.GetBool("backend", "skipverify") if skipverify { - log.Println("WARNING: Backend verification is disabled!") + logger.Println("WARNING: Backend verification is disabled!") } pool, err := NewHttpClientPool(maxConcurrentRequestsPerHost, skipverify) @@ -118,6 +118,7 @@ func isOcsRequest(u *url.URL) bool { // PerformJSONRequest sends a JSON POST request to the given url and decodes // the result into "response". func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, request any, response any) error { + logger := LoggerFromContext(ctx) if u == nil { return fmt.Errorf("no url passed to perform JSON request %+v", request) } @@ -139,21 +140,21 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ c, pool, err := b.pool.Get(ctx, u) if err != nil { - log.Printf("Could not get client for host %s: %s", u.Host, err) + logger.Printf("Could not get client for host %s: %s", u.Host, err) return err } defer pool.Put(c) data, err := b.buffers.MarshalAsJSON(request) if err != nil { - log.Printf("Could not marshal request %+v: %s", request, err) + logger.Printf("Could not marshal request %+v: %s", request, err) return err } defer b.buffers.Put(data) req, err := http.NewRequestWithContext(ctx, "POST", requestUrl.String(), data) if err != nil { - log.Printf("Could not create request to %s: %s", requestUrl, err) + logger.Printf("Could not create request to %s: %s", requestUrl, err) return err } req.Header.Set("Content-Type", "application/json") @@ -181,21 +182,21 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ } else { statsBackendClientError.WithLabelValues(backend.Id(), "unknown").Inc() } - log.Printf("Could not send request %s to %s: %s", data.String(), req.URL, err) + logger.Printf("Could not send request %s to %s: %s", data.String(), req.URL, err) return err } defer resp.Body.Close() ct := resp.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/json") { - log.Printf("Received unsupported content-type from %s for %s: %s (%s)", req.URL, data.String(), ct, resp.Status) + logger.Printf("Received unsupported content-type from %s for %s: %s (%s)", req.URL, data.String(), ct, resp.Status) statsBackendClientError.WithLabelValues(backend.Id(), "invalid_content_type").Inc() return ErrUnsupportedContentType } body, err := b.buffers.ReadAll(resp.Body) if err != nil { - log.Printf("Could not read response body from %s for %s: %s", req.URL, data.String(), err) + logger.Printf("Could not read response body from %s for %s: %s", req.URL, data.String(), err) statsBackendClientError.WithLabelValues(backend.Id(), "error_reading_body").Inc() return err } @@ -213,29 +214,29 @@ func (b *BackendClient) PerformJSONRequest(ctx context.Context, u *url.URL, requ // } var ocs OcsResponse if err := json.Unmarshal(body.Bytes(), &ocs); err != nil { - log.Printf("Could not decode OCS response %s from %s: %s", body.String(), req.URL, err) + logger.Printf("Could not decode OCS response %s from %s: %s", body.String(), req.URL, err) statsBackendClientError.WithLabelValues(backend.Id(), "error_decoding_ocs").Inc() return err } else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 { - log.Printf("Incomplete OCS response %s from %s", body.String(), req.URL) + logger.Printf("Incomplete OCS response %s from %s", body.String(), req.URL) statsBackendClientError.WithLabelValues(backend.Id(), "error_incomplete_ocs").Inc() return ErrIncompleteResponse } switch ocs.Ocs.Meta.StatusCode { case http.StatusTooManyRequests: - log.Printf("Throttled OCS response %s from %s", body.String(), req.URL) + logger.Printf("Throttled OCS response %s from %s", body.String(), req.URL) statsBackendClientError.WithLabelValues(backend.Id(), "throttled").Inc() return ErrThrottledResponse } if err := json.Unmarshal(ocs.Ocs.Data, response); err != nil { - log.Printf("Could not decode OCS response body %s from %s: %s", string(ocs.Ocs.Data), req.URL, err) + logger.Printf("Could not decode OCS response body %s from %s: %s", string(ocs.Ocs.Data), req.URL, err) statsBackendClientError.WithLabelValues(backend.Id(), "error_decoding_ocs_data").Inc() return err } } else if err := json.Unmarshal(body.Bytes(), response); err != nil { - log.Printf("Could not decode response body %s from %s: %s", body.String(), req.URL, err) + logger.Printf("Could not decode response body %s from %s: %s", body.String(), req.URL, err) statsBackendClientError.WithLabelValues(backend.Id(), "error_decoding_body").Inc() return err } diff --git a/backend_client_test.go b/backend_client_test.go index d0a4d77..71c7824 100644 --- a/backend_client_test.go +++ b/backend_client_test.go @@ -22,7 +22,6 @@ package signaling import ( - "context" "encoding/json" "io" "net/http" @@ -67,7 +66,8 @@ func returnOCS(t *testing.T, w http.ResponseWriter, body []byte) { func TestPostOnRedirect(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) r := mux.NewRouter() r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { @@ -96,10 +96,9 @@ func TestPostOnRedirect(t *testing.T) { if u.Scheme == "http" { config.AddOption("backend", "allowhttp", "true") } - client, err := NewBackendClient(config, 1, "0.0", nil) + client, err := NewBackendClient(ctx, config, 1, "0.0", nil) require.NoError(err) - ctx := context.Background() request := map[string]string{ "foo": "bar", } @@ -114,7 +113,8 @@ func TestPostOnRedirect(t *testing.T) { func TestPostOnRedirectDifferentHost(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) r := mux.NewRouter() r.HandleFunc("/ocs/v2.php/one", func(w http.ResponseWriter, r *http.Request) { @@ -132,10 +132,9 @@ func TestPostOnRedirectDifferentHost(t *testing.T) { if u.Scheme == "http" { config.AddOption("backend", "allowhttp", "true") } - client, err := NewBackendClient(config, 1, "0.0", nil) + client, err := NewBackendClient(ctx, config, 1, "0.0", nil) require.NoError(err) - ctx := context.Background() request := map[string]string{ "foo": "bar", } @@ -151,7 +150,8 @@ func TestPostOnRedirectDifferentHost(t *testing.T) { func TestPostOnRedirectStatusFound(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) r := mux.NewRouter() @@ -177,10 +177,9 @@ func TestPostOnRedirectStatusFound(t *testing.T) { if u.Scheme == "http" { config.AddOption("backend", "allowhttp", "true") } - client, err := NewBackendClient(config, 1, "0.0", nil) + client, err := NewBackendClient(ctx, config, 1, "0.0", nil) require.NoError(err) - ctx := context.Background() request := map[string]string{ "foo": "bar", } @@ -193,7 +192,8 @@ func TestPostOnRedirectStatusFound(t *testing.T) { func TestHandleThrottled(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) r := mux.NewRouter() @@ -212,10 +212,9 @@ func TestHandleThrottled(t *testing.T) { if u.Scheme == "http" { config.AddOption("backend", "allowhttp", "true") } - client, err := NewBackendClient(config, 1, "0.0", nil) + client, err := NewBackendClient(ctx, config, 1, "0.0", nil) require.NoError(err) - ctx := context.Background() request := map[string]string{ "foo": "bar", } diff --git a/backend_configuration.go b/backend_configuration.go index d79bd7e..d9be669 100644 --- a/backend_configuration.go +++ b/backend_configuration.go @@ -224,7 +224,7 @@ type BackendConfiguration struct { storage BackendStorage } -func NewBackendConfiguration(config *goconf.ConfigFile, etcdClient *EtcdClient) (*BackendConfiguration, error) { +func NewBackendConfiguration(logger Logger, config *goconf.ConfigFile, etcdClient *EtcdClient) (*BackendConfiguration, error) { backendType, _ := config.GetString("backend", "backendtype") if backendType == "" { backendType = DefaultBackendType @@ -236,9 +236,9 @@ func NewBackendConfiguration(config *goconf.ConfigFile, etcdClient *EtcdClient) var err error switch backendType { case BackendTypeStatic: - storage, err = NewBackendStorageStatic(config) + storage, err = NewBackendStorageStatic(logger, config) case BackendTypeEtcd: - storage, err = NewBackendStorageEtcd(config, etcdClient) + storage, err = NewBackendStorageEtcd(logger, config, etcdClient) default: err = fmt.Errorf("unknown backend type: %s", backendType) } diff --git a/backend_configuration_test.go b/backend_configuration_test.go index d5d7bdd..e22d160 100644 --- a/backend_configuration_test.go +++ b/backend_configuration_test.go @@ -81,7 +81,7 @@ func testBackends(t *testing.T, config *BackendConfiguration, valid_urls [][]str } func TestIsUrlAllowed_Compat(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) // Old-style configuration valid_urls := []string{ "http://domain.invalid", @@ -96,13 +96,13 @@ func TestIsUrlAllowed_Compat(t *testing.T) { config.AddOption("backend", "allowed", "domain.invalid") config.AddOption("backend", "allowhttp", "true") config.AddOption("backend", "secret", string(testBackendSecret)) - cfg, err := NewBackendConfiguration(config, nil) + cfg, err := NewBackendConfiguration(logger, config, nil) require.NoError(t, err) testUrls(t, cfg, valid_urls, invalid_urls) } func TestIsUrlAllowed_CompatForceHttps(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) // Old-style configuration, force HTTPS valid_urls := []string{ "https://domain.invalid", @@ -116,13 +116,13 @@ func TestIsUrlAllowed_CompatForceHttps(t *testing.T) { config := goconf.NewConfigFile() config.AddOption("backend", "allowed", "domain.invalid") config.AddOption("backend", "secret", string(testBackendSecret)) - cfg, err := NewBackendConfiguration(config, nil) + cfg, err := NewBackendConfiguration(logger, config, nil) require.NoError(t, err) testUrls(t, cfg, valid_urls, invalid_urls) } func TestIsUrlAllowed(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) valid_urls := [][]string{ {"https://domain.invalid/foo", string(testBackendSecret) + "-foo"}, {"https://domain.invalid/foo/", string(testBackendSecret) + "-foo"}, @@ -160,13 +160,13 @@ func TestIsUrlAllowed(t *testing.T) { config.AddOption("baz", "secret", string(testBackendSecret)+"-baz") config.AddOption("lala", "url", "https://otherdomain.invalid/") config.AddOption("lala", "secret", string(testBackendSecret)+"-lala") - cfg, err := NewBackendConfiguration(config, nil) + cfg, err := NewBackendConfiguration(logger, config, nil) require.NoError(t, err) testBackends(t, cfg, valid_urls, invalid_urls) } func TestIsUrlAllowed_EmptyAllowlist(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) valid_urls := []string{} invalid_urls := []string{ "http://domain.invalid", @@ -176,13 +176,13 @@ func TestIsUrlAllowed_EmptyAllowlist(t *testing.T) { config := goconf.NewConfigFile() config.AddOption("backend", "allowed", "") config.AddOption("backend", "secret", string(testBackendSecret)) - cfg, err := NewBackendConfiguration(config, nil) + cfg, err := NewBackendConfiguration(logger, config, nil) require.NoError(t, err) testUrls(t, cfg, valid_urls, invalid_urls) } func TestIsUrlAllowed_AllowAll(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) valid_urls := []string{ "http://domain.invalid", "https://domain.invalid", @@ -195,7 +195,7 @@ func TestIsUrlAllowed_AllowAll(t *testing.T) { config.AddOption("backend", "allowall", "true") config.AddOption("backend", "allowed", "") config.AddOption("backend", "secret", string(testBackendSecret)) - cfg, err := NewBackendConfiguration(config, nil) + cfg, err := NewBackendConfiguration(logger, config, nil) require.NoError(t, err) testUrls(t, cfg, valid_urls, invalid_urls) } @@ -206,7 +206,6 @@ type ParseBackendIdsTestcase struct { } func TestParseBackendIds(t *testing.T) { - CatchLogForTest(t) testcases := []ParseBackendIdsTestcase{ {"", nil}, {"backend1", []string{"backend1"}}, @@ -227,7 +226,7 @@ func TestParseBackendIds(t *testing.T) { func TestBackendReloadNoChange(t *testing.T) { ResetStatsValue(t, statsBackendsCurrent) - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") @@ -236,7 +235,7 @@ func TestBackendReloadNoChange(t *testing.T) { original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") original_config.AddOption("backend2", "url", "http://domain2.invalid") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") - o_cfg, err := NewBackendConfiguration(original_config, nil) + o_cfg, err := NewBackendConfiguration(logger, original_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 2) @@ -247,7 +246,7 @@ func TestBackendReloadNoChange(t *testing.T) { new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") new_config.AddOption("backend2", "url", "http://domain2.invalid") new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") - n_cfg, err := NewBackendConfiguration(new_config, nil) + n_cfg, err := NewBackendConfiguration(logger, new_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 4) @@ -261,7 +260,7 @@ func TestBackendReloadNoChange(t *testing.T) { func TestBackendReloadChangeExistingURL(t *testing.T) { ResetStatsValue(t, statsBackendsCurrent) - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") @@ -270,7 +269,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) { original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") original_config.AddOption("backend2", "url", "http://domain2.invalid") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") - o_cfg, err := NewBackendConfiguration(original_config, nil) + o_cfg, err := NewBackendConfiguration(logger, original_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 2) @@ -282,7 +281,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) { new_config.AddOption("backend1", "sessionlimit", "10") new_config.AddOption("backend2", "url", "http://domain2.invalid") new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") - n_cfg, err := NewBackendConfiguration(new_config, nil) + n_cfg, err := NewBackendConfiguration(logger, new_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 4) @@ -300,7 +299,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) { func TestBackendReloadChangeSecret(t *testing.T) { ResetStatsValue(t, statsBackendsCurrent) - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") @@ -309,7 +308,7 @@ func TestBackendReloadChangeSecret(t *testing.T) { original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") original_config.AddOption("backend2", "url", "http://domain2.invalid") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") - o_cfg, err := NewBackendConfiguration(original_config, nil) + o_cfg, err := NewBackendConfiguration(logger, original_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 2) @@ -320,7 +319,7 @@ func TestBackendReloadChangeSecret(t *testing.T) { new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend3") new_config.AddOption("backend2", "url", "http://domain2.invalid") new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") - n_cfg, err := NewBackendConfiguration(new_config, nil) + n_cfg, err := NewBackendConfiguration(logger, new_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 4) @@ -335,14 +334,14 @@ func TestBackendReloadChangeSecret(t *testing.T) { func TestBackendReloadAddBackend(t *testing.T) { ResetStatsValue(t, statsBackendsCurrent) - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1") original_config.AddOption("backend", "allowall", "false") original_config.AddOption("backend1", "url", "http://domain1.invalid") original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") - o_cfg, err := NewBackendConfiguration(original_config, nil) + o_cfg, err := NewBackendConfiguration(logger, original_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 1) @@ -354,7 +353,7 @@ func TestBackendReloadAddBackend(t *testing.T) { new_config.AddOption("backend2", "url", "http://domain2.invalid") new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") new_config.AddOption("backend2", "sessionlimit", "10") - n_cfg, err := NewBackendConfiguration(new_config, nil) + n_cfg, err := NewBackendConfiguration(logger, new_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 3) @@ -374,7 +373,7 @@ func TestBackendReloadAddBackend(t *testing.T) { func TestBackendReloadRemoveHost(t *testing.T) { ResetStatsValue(t, statsBackendsCurrent) - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") @@ -383,7 +382,7 @@ func TestBackendReloadRemoveHost(t *testing.T) { original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") original_config.AddOption("backend2", "url", "http://domain2.invalid") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") - o_cfg, err := NewBackendConfiguration(original_config, nil) + o_cfg, err := NewBackendConfiguration(logger, original_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 2) @@ -392,7 +391,7 @@ func TestBackendReloadRemoveHost(t *testing.T) { new_config.AddOption("backend", "allowall", "false") new_config.AddOption("backend1", "url", "http://domain1.invalid") new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") - n_cfg, err := NewBackendConfiguration(new_config, nil) + n_cfg, err := NewBackendConfiguration(logger, new_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 3) @@ -410,7 +409,7 @@ func TestBackendReloadRemoveHost(t *testing.T) { func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) { ResetStatsValue(t, statsBackendsCurrent) - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) original_config := goconf.NewConfigFile() original_config.AddOption("backend", "backends", "backend1, backend2") @@ -419,7 +418,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) { original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") original_config.AddOption("backend2", "url", "http://domain1.invalid/bar/") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") - o_cfg, err := NewBackendConfiguration(original_config, nil) + o_cfg, err := NewBackendConfiguration(logger, original_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 2) @@ -428,7 +427,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) { new_config.AddOption("backend", "allowall", "false") new_config.AddOption("backend1", "url", "http://domain1.invalid/foo/") new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") - n_cfg, err := NewBackendConfiguration(new_config, nil) + n_cfg, err := NewBackendConfiguration(logger, new_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 3) @@ -462,7 +461,7 @@ func mustParse(s string) *url.URL { func TestBackendConfiguration_EtcdCompat(t *testing.T) { ResetStatsValue(t, statsBackendsCurrent) - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) assert := assert.New(t) etcd, client := NewEtcdClientForTest(t) @@ -479,7 +478,7 @@ func TestBackendConfiguration_EtcdCompat(t *testing.T) { checkStatsValue(t, statsBackendsCurrent, 0) - cfg, err := NewBackendConfiguration(config, client) + cfg, err := NewBackendConfiguration(logger, config, client) require.NoError(err) defer cfg.Close() @@ -581,7 +580,7 @@ func TestBackendConfiguration_EtcdCompat(t *testing.T) { func TestBackendCommonSecret(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) assert := assert.New(t) u1, err := url.Parse("http://domain1.invalid") @@ -594,7 +593,7 @@ func TestBackendCommonSecret(t *testing.T) { original_config.AddOption("backend1", "url", u1.String()) original_config.AddOption("backend2", "url", u2.String()) original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") - cfg, err := NewBackendConfiguration(original_config, nil) + cfg, err := NewBackendConfiguration(logger, original_config, nil) require.NoError(err) if b1 := cfg.GetBackend(u1); assert.NotNil(b1) { @@ -623,7 +622,7 @@ func TestBackendCommonSecret(t *testing.T) { func TestBackendChangeUrls(t *testing.T) { ResetStatsValue(t, statsBackendsCurrent) - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) assert := assert.New(t) u1, err := url.Parse("http://domain1.invalid/") @@ -638,7 +637,7 @@ func TestBackendChangeUrls(t *testing.T) { checkStatsValue(t, statsBackendsCurrent, 0) - cfg, err := NewBackendConfiguration(original_config, nil) + cfg, err := NewBackendConfiguration(logger, original_config, nil) require.NoError(err) checkStatsValue(t, statsBackendsCurrent, 2) @@ -714,7 +713,7 @@ func TestBackendChangeUrls(t *testing.T) { func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) { ResetStatsValue(t, statsBackendsCurrent) - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) assert := assert.New(t) etcd, client := NewEtcdClientForTest(t) @@ -731,7 +730,7 @@ func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) { checkStatsValue(t, statsBackendsCurrent, 0) - cfg, err := NewBackendConfiguration(config, client) + cfg, err := NewBackendConfiguration(logger, config, client) require.NoError(err) defer cfg.Close() diff --git a/backend_server.go b/backend_server.go index ed6ee3c..23975be 100644 --- a/backend_server.go +++ b/backend_server.go @@ -31,7 +31,6 @@ import ( "errors" "fmt" "io" - "log" "net" "net/http" "net/http/pprof" @@ -63,6 +62,7 @@ const ( ) type BackendServer struct { + logger Logger hub *Hub events AsyncEvents roomSessions RoomSessions @@ -82,7 +82,8 @@ type BackendServer struct { buffers BufferPool } -func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*BackendServer, error) { +func NewBackendServer(ctx context.Context, config *goconf.ConfigFile, hub *Hub, version string) (*BackendServer, error) { + logger := LoggerFromContext(ctx) turnapikey, _ := GetStringOptionWithEnv(config, "turn", "apikey") turnsecret, _ := GetStringOptionWithEnv(config, "turn", "secret") turnservers, _ := config.GetString("turn", "servers") @@ -98,10 +99,10 @@ func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*Bac return nil, fmt.Errorf("need a shared TURN secret if TURN servers are configured") } - log.Printf("Using configured TURN API key") - log.Printf("Using configured shared TURN secret") + logger.Printf("Using configured TURN API key") + logger.Printf("Using configured shared TURN secret") for _, s := range turnserverslist { - log.Printf("Adding \"%s\" as TURN server", s) + logger.Printf("Adding \"%s\" as TURN server", s) } } @@ -112,10 +113,10 @@ func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*Bac } if !statsAllowedIps.Empty() { - log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) + logger.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) } else { statsAllowedIps = DefaultAllowedIps() - log.Printf("No IPs configured for the stats endpoint, only allowing access from %s", statsAllowedIps) + logger.Printf("No IPs configured for the stats endpoint, only allowing access from %s", statsAllowedIps) } invalidSecret := make([]byte, 32) @@ -126,6 +127,7 @@ func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*Bac debug, _ := config.GetBool("app", "debug") result := &BackendServer{ + logger: logger, hub: hub, events: hub.events, roomSessions: hub.roomSessions, @@ -149,14 +151,14 @@ func (b *BackendServer) Reload(config *goconf.ConfigFile) { statsAllowed, _ := config.GetString("stats", "allowed_ips") if statsAllowedIps, err := ParseAllowedIps(statsAllowed); err == nil { if !statsAllowedIps.Empty() { - log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) + b.logger.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) } else { statsAllowedIps = DefaultAllowedIps() - log.Printf("No IPs configured for the stats endpoint, only allowing access from %s", statsAllowedIps) + b.logger.Printf("No IPs configured for the stats endpoint, only allowing access from %s", statsAllowedIps) } b.statsAllowedIps.Store(statsAllowedIps) } else { - log.Printf("Error parsing allowed stats ips from \"%s\": %s", statsAllowedIps, err) + b.logger.Printf("Error parsing allowed stats ips from \"%s\": %s", statsAllowedIps, err) } } @@ -174,7 +176,7 @@ func (b *BackendServer) Start(r *mux.Router) error { b.welcomeMessage = string(welcomeMessage) + "\n" if b.debug { - log.Println("Installing debug handlers in \"/debug/pprof\"") + b.logger.Println("Installing debug handlers in \"/debug/pprof\"") s := r.PathPrefix("/debug/pprof").Subrouter() s.HandleFunc("", b.setCommonHeaders(b.validateStatsRequest(func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/debug/pprof/", http.StatusTemporaryRedirect) @@ -273,7 +275,7 @@ func (b *BackendServer) getTurnCredentials(w http.ResponseWriter, r *http.Reques data, err := json.Marshal(result) if err != nil { - log.Printf("Could not serialize TURN credentials: %s", err) + b.logger.Printf("Could not serialize TURN credentials: %s", err) w.WriteHeader(http.StatusInternalServerError) io.WriteString(w, "Could not serialize credentials.") // nolint return @@ -288,7 +290,7 @@ func (b *BackendServer) getTurnCredentials(w http.ResponseWriter, r *http.Reques w.Write(data) // nolint } -func (b *BackendServer) parseRequestBody(f func(http.ResponseWriter, *http.Request, []byte)) func(http.ResponseWriter, *http.Request) { +func (b *BackendServer) parseRequestBody(f func(context.Context, http.ResponseWriter, *http.Request, []byte)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { // Sanity checks if r.ContentLength == -1 { @@ -300,7 +302,7 @@ func (b *BackendServer) parseRequestBody(f func(http.ResponseWriter, *http.Reque } ct := r.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/json") { - log.Printf("Received unsupported content-type: %s", ct) + b.logger.Printf("Received unsupported content-type: %s", ct) http.Error(w, "Unsupported Content-Type", http.StatusBadRequest) return } @@ -313,13 +315,14 @@ func (b *BackendServer) parseRequestBody(f func(http.ResponseWriter, *http.Reque body, err := b.buffers.ReadAll(r.Body) if err != nil { - log.Println("Error reading body: ", err) + b.logger.Println("Error reading body: ", err) http.Error(w, "Could not read body", http.StatusBadRequest) return } defer b.buffers.Put(body) - f(w, r, body.Bytes()) + ctx := NewLoggerContext(r.Context(), b.logger) + f(ctx, w, r, body.Bytes()) } } @@ -340,7 +343,7 @@ func (b *BackendServer) sendRoomInvite(roomid string, backend *Backend, userids } for _, userid := range userids { if err := b.events.PublishUserMessage(userid, backend, msg); err != nil { - log.Printf("Could not publish room invite for user %s in backend %s: %s", userid, backend.Id(), err) + b.logger.Printf("Could not publish room invite for user %s in backend %s: %s", userid, backend.Id(), err) } } } @@ -364,12 +367,13 @@ func (b *BackendServer) sendRoomDisinvite(roomid string, backend *Backend, reaso } for _, userid := range userids { if err := b.events.PublishUserMessage(userid, backend, msg); err != nil { - log.Printf("Could not publish room disinvite for user %s in backend %s: %s", userid, backend.Id(), err) + b.logger.Printf("Could not publish room disinvite for user %s in backend %s: %s", userid, backend.Id(), err) } } timeout := time.Second - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx := NewLoggerContext(context.Background(), b.logger) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() var wg sync.WaitGroup for _, sessionid := range sessionids { @@ -382,10 +386,10 @@ func (b *BackendServer) sendRoomDisinvite(roomid string, backend *Backend, reaso go func(sessionid RoomSessionId) { defer wg.Done() if sid, err := b.lookupByRoomSessionId(ctx, sessionid, nil); err != nil { - log.Printf("Could not lookup by room session %s: %s", sessionid, err) + b.logger.Printf("Could not lookup by room session %s: %s", sessionid, err) } else if sid != "" { if err := b.events.PublishSessionMessage(sid, backend, msg); err != nil { - log.Printf("Could not publish room disinvite for session %s: %s", sid, err) + b.logger.Printf("Could not publish room disinvite for session %s: %s", sid, err) } } }(sessionid) @@ -419,14 +423,14 @@ func (b *BackendServer) sendRoomUpdate(roomid string, backend *Backend, notified } if err := b.events.PublishUserMessage(userid, backend, msg); err != nil { - log.Printf("Could not publish room update for user %s in backend %s: %s", userid, backend.Id(), err) + b.logger.Printf("Could not publish room update for user %s in backend %s: %s", userid, backend.Id(), err) } } } func (b *BackendServer) lookupByRoomSessionId(ctx context.Context, roomSessionId RoomSessionId, cache *ConcurrentMap[RoomSessionId, PublicSessionId]) (PublicSessionId, error) { if roomSessionId == sessionIdNotInMeeting { - log.Printf("Trying to lookup empty room session id: %s", roomSessionId) + b.logger.Printf("Trying to lookup empty room session id: %s", roomSessionId) return "", nil } @@ -458,13 +462,13 @@ func (b *BackendServer) fixupUserSessions(ctx context.Context, cache *Concurrent for _, user := range users { roomSessionId, found := api.GetStringMapString[RoomSessionId](user, "sessionId") if !found { - log.Printf("User %+v has invalid room session id, ignoring", user) + b.logger.Printf("User %+v has invalid room session id, ignoring", user) delete(user, "sessionId") continue } if roomSessionId == sessionIdNotInMeeting { - log.Printf("User %+v is not in the meeting, ignoring", user) + b.logger.Printf("User %+v is not in the meeting, ignoring", user) delete(user, "sessionId") continue } @@ -473,7 +477,7 @@ func (b *BackendServer) fixupUserSessions(ctx context.Context, cache *Concurrent go func(roomSessionId RoomSessionId, u api.StringMap) { defer wg.Done() if sessionId, err := b.lookupByRoomSessionId(ctx, roomSessionId, cache); err != nil { - log.Printf("Could not lookup by room session %s: %s", roomSessionId, err) + b.logger.Printf("Could not lookup by room session %s: %s", roomSessionId, err) delete(u, "sessionId") } else if sessionId != "" { u["sessionId"] = sessionId @@ -498,7 +502,8 @@ func (b *BackendServer) sendRoomIncall(roomid string, backend *Backend, request if !request.InCall.All { timeout := time.Second - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx := NewLoggerContext(context.Background(), b.logger) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() var cache ConcurrentMap[RoomSessionId, PublicSessionId] // Convert (Nextcloud) session ids to signaling session ids. @@ -518,11 +523,11 @@ func (b *BackendServer) sendRoomIncall(roomid string, backend *Backend, request return b.events.PublishBackendRoomMessage(roomid, backend, message) } -func (b *BackendServer) sendRoomParticipantsUpdate(roomid string, backend *Backend, request *BackendServerRoomRequest) error { +func (b *BackendServer) sendRoomParticipantsUpdate(ctx context.Context, roomid string, backend *Backend, request *BackendServerRoomRequest) error { timeout := time.Second // Convert (Nextcloud) session ids to signaling session ids. - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() var cache ConcurrentMap[RoomSessionId, PublicSessionId] request.Participants.Users = b.fixupUserSessions(ctx, &cache, request.Participants.Users) @@ -542,20 +547,20 @@ loop: sessionId, found := api.GetStringMapString[PublicSessionId](user, "sessionId") if !found { - log.Printf("User entry has no session id: %+v", user) + b.logger.Printf("User entry has no session id: %+v", user) continue } permissionsList, ok := permissionsInterface.([]any) if !ok { - log.Printf("Received invalid permissions %+v (%s) for session %s", permissionsInterface, reflect.TypeOf(permissionsInterface), sessionId) + b.logger.Printf("Received invalid permissions %+v (%s) for session %s", permissionsInterface, reflect.TypeOf(permissionsInterface), sessionId) continue } var permissions []Permission for idx, ob := range permissionsList { permission, ok := ob.(string) if !ok { - log.Printf("Received invalid permission at position %d %+v (%s) for session %s", idx, ob, reflect.TypeOf(ob), sessionId) + b.logger.Printf("Received invalid permission at position %d %+v (%s) for session %s", idx, ob, reflect.TypeOf(ob), sessionId) continue loop } permissions = append(permissions, Permission(permission)) @@ -569,7 +574,7 @@ loop: Permissions: permissions, } if err := b.events.PublishSessionMessage(sessionId, backend, message); err != nil { - log.Printf("Could not send permissions update (%+v) to session %s: %s", permissions, sessionId, err) + b.logger.Printf("Could not send permissions update (%+v) to session %s: %s", permissions, sessionId, err) } }(sessionId, permissions) } @@ -590,11 +595,11 @@ func (b *BackendServer) sendRoomMessage(roomid string, backend *Backend, request return b.events.PublishBackendRoomMessage(roomid, backend, message) } -func (b *BackendServer) sendRoomSwitchTo(roomid string, backend *Backend, request *BackendServerRoomRequest) error { +func (b *BackendServer) sendRoomSwitchTo(ctx context.Context, roomid string, backend *Backend, request *BackendServerRoomRequest) error { timeout := time.Second // Convert (Nextcloud) session ids to signaling session ids. - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() var wg sync.WaitGroup @@ -621,7 +626,7 @@ func (b *BackendServer) sendRoomSwitchTo(roomid string, backend *Backend, reques go func(roomSessionId RoomSessionId) { defer wg.Done() if sessionId, err := b.lookupByRoomSessionId(ctx, roomSessionId, nil); err != nil { - log.Printf("Could not lookup by room session %s: %s", roomSessionId, err) + b.logger.Printf("Could not lookup by room session %s: %s", roomSessionId, err) } else if sessionId != "" { mu.Lock() defer mu.Unlock() @@ -659,7 +664,7 @@ func (b *BackendServer) sendRoomSwitchTo(roomid string, backend *Backend, reques go func(roomSessionId RoomSessionId, details json.RawMessage) { defer wg.Done() if sessionId, err := b.lookupByRoomSessionId(ctx, roomSessionId, nil); err != nil { - log.Printf("Could not lookup by room session %s: %s", roomSessionId, err) + b.logger.Printf("Could not lookup by room session %s: %s", roomSessionId, err) } else if sessionId != "" { mu.Lock() defer mu.Unlock() @@ -819,7 +824,7 @@ func (b *BackendServer) startDialout(ctx context.Context, roomid string, backend response, err := b.startDialoutInSession(ctx, session, roomid, backend, backendUrl, request) if err != nil { - log.Printf("Error starting dialout request %+v in session %s: %+v", request.Dialout, session.PublicId(), err) + b.logger.Printf("Error starting dialout request %+v in session %s: %+v", request.Dialout, session.PublicId(), err) var e *Error if sessionError == nil && errors.As(err, &e) { sessionError = e @@ -837,13 +842,13 @@ func (b *BackendServer) startDialout(ctx context.Context, roomid string, backend return returnDialoutError(http.StatusNotFound, NewError("no_client_available", "No available client found to trigger dialout.")) } -func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body []byte) { - throttle, err := b.hub.throttler.CheckBruteforce(r.Context(), b.hub.getRealUserIP(r), "BackendRoomAuth") +func (b *BackendServer) roomHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, body []byte) { + throttle, err := b.hub.throttler.CheckBruteforce(ctx, b.hub.getRealUserIP(r), "BackendRoomAuth") if err == ErrBruteforceDetected { http.Error(w, "Too many requests", http.StatusTooManyRequests) return } else if err != nil { - log.Printf("Error checking for bruteforce: %s", err) + b.logger.Printf("Error checking for bruteforce: %s", err) http.Error(w, "Could not check for bruteforce", http.StatusInternalServerError) return } @@ -860,7 +865,7 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body if backend == nil { // Unknown backend URL passed, return immediately. - throttle(r.Context()) + throttle(ctx) http.Error(w, "Authentication check failed", http.StatusForbidden) return } @@ -882,21 +887,21 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body } if backend == nil { - throttle(r.Context()) + throttle(ctx) http.Error(w, "Authentication check failed", http.StatusForbidden) return } } if !ValidateBackendChecksum(r, body, backend.Secret()) { - throttle(r.Context()) + throttle(ctx) http.Error(w, "Authentication check failed", http.StatusForbidden) return } var request BackendServerRoomRequest if err := json.Unmarshal(body, &request); err != nil { - log.Printf("Error decoding body %s: %s", string(body), err) + b.logger.Printf("Error decoding body %s: %s", string(body), err) http.Error(w, "Could not read body", http.StatusBadRequest) return } @@ -928,20 +933,20 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body case "incall": err = b.sendRoomIncall(roomid, backend, &request) case "participants": - err = b.sendRoomParticipantsUpdate(roomid, backend, &request) + err = b.sendRoomParticipantsUpdate(ctx, roomid, backend, &request) case "message": err = b.sendRoomMessage(roomid, backend, &request) case "switchto": - err = b.sendRoomSwitchTo(roomid, backend, &request) + err = b.sendRoomSwitchTo(ctx, roomid, backend, &request) case "dialout": - response, err = b.startDialout(r.Context(), roomid, backend, backendUrl, &request) + response, err = b.startDialout(ctx, roomid, backend, backendUrl, &request) default: http.Error(w, "Unsupported request type: "+request.Type, http.StatusBadRequest) return } if err != nil { - log.Printf("Error processing %s for room %s: %s", string(body), roomid, err) + b.logger.Printf("Error processing %s for room %s: %s", string(body), roomid, err) http.Error(w, "Error while processing", http.StatusInternalServerError) return } @@ -957,7 +962,7 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body } responseData, err = json.Marshal(response) if err != nil { - log.Printf("Could not serialize backend response %+v: %s", response, err) + b.logger.Printf("Could not serialize backend response %+v: %s", response, err) responseStatus = http.StatusInternalServerError responseData = []byte("{\"error\":\"could_not_serialize\"}") } @@ -995,7 +1000,7 @@ func (b *BackendServer) statsHandler(w http.ResponseWriter, r *http.Request) { stats := b.hub.GetStats() statsData, err := json.MarshalIndent(stats, "", " ") if err != nil { - log.Printf("Could not serialize stats %+v: %s", stats, err) + b.logger.Printf("Could not serialize stats %+v: %s", stats, err) http.Error(w, "Internal server error", http.StatusInternalServerError) return } @@ -1028,7 +1033,7 @@ func (b *BackendServer) serverinfoHandler(w http.ResponseWriter, r *http.Request infoData, err := json.MarshalIndent(info, "", " ") if err != nil { - log.Printf("Could not serialize server info %+v: %s", info, err) + b.logger.Printf("Could not serialize server info %+v: %s", info, err) http.Error(w, "Internal server error", http.StatusInternalServerError) return } diff --git a/backend_server_test.go b/backend_server_test.go index 5d9eb95..460bb7d 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -99,9 +99,11 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil config.AddOption("clients", "internalsecret", string(testInternalSecret)) config.AddOption("geoip", "url", "none") events := getAsyncEventsForTest(t) - hub, err := NewHub(config, events, nil, nil, nil, r, "no-version") + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + hub, err := NewHub(ctx, config, events, nil, nil, nil, r, "no-version") require.NoError(err) - b, err := NewBackendServer(config, hub, "no-version") + b, err := NewBackendServer(ctx, config, hub, "no-version") require.NoError(err) require.NoError(b.Start(r)) @@ -158,13 +160,16 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g config1.AddOption("clients", "internalsecret", string(testInternalSecret)) config1.AddOption("geoip", "url", "none") - events1, err := NewAsyncEvents(nats.ClientURL()) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + + events1, err := NewAsyncEvents(ctx, nats.ClientURL()) require.NoError(err) t.Cleanup(func() { events1.Close() }) client1, _ := NewGrpcClientsForTest(t, addr2) - hub1, err := NewHub(config1, events1, grpcServer1, client1, nil, r1, "no-version") + hub1, err := NewHub(ctx, config1, events1, grpcServer1, client1, nil, r1, "no-version") require.NoError(err) if config2 == nil { @@ -181,19 +186,19 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g config2.AddOption("sessions", "blockkey", "09876543210987654321098765432109") config2.AddOption("clients", "internalsecret", string(testInternalSecret)) config2.AddOption("geoip", "url", "none") - events2, err := NewAsyncEvents(nats.ClientURL()) + events2, err := NewAsyncEvents(ctx, nats.ClientURL()) require.NoError(err) t.Cleanup(func() { events2.Close() }) client2, _ := NewGrpcClientsForTest(t, addr1) - hub2, err := NewHub(config2, events2, grpcServer2, client2, nil, r2, "no-version") + hub2, err := NewHub(ctx, config2, events2, grpcServer2, client2, nil, r2, "no-version") require.NoError(err) - b1, err := NewBackendServer(config1, hub1, "no-version") + b1, err := NewBackendServer(ctx, config1, hub1, "no-version") require.NoError(err) require.NoError(b1.Start(r1)) - b2, err := NewBackendServer(config2, hub2, "no-version") + b2, err := NewBackendServer(ctx, config2, hub2, "no-version") require.NoError(err) require.NoError(b2.Start(r2)) @@ -258,7 +263,6 @@ func expectRoomlistEvent(t *testing.T, ch chan *AsyncMessage, msgType string) (* func TestBackendServer_NoAuth(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTest(t) @@ -280,7 +284,6 @@ func TestBackendServer_NoAuth(t *testing.T) { func TestBackendServer_InvalidAuth(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTest(t) @@ -304,7 +307,6 @@ func TestBackendServer_InvalidAuth(t *testing.T) { func TestBackendServer_OldCompatAuth(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTest(t) @@ -347,7 +349,6 @@ func TestBackendServer_OldCompatAuth(t *testing.T) { func TestBackendServer_InvalidBody(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTest(t) @@ -364,7 +365,6 @@ func TestBackendServer_InvalidBody(t *testing.T) { func TestBackendServer_UnsupportedRequest(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTest(t) @@ -385,11 +385,12 @@ func TestBackendServer_UnsupportedRequest(t *testing.T) { } func TestBackendServer_RoomInvite(t *testing.T) { - CatchLogForTest(t) for _, backend := range eventBackendsForTest { t.Run(backend, func(t *testing.T) { t.Parallel() - RunTestBackendServer_RoomInvite(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + RunTestBackendServer_RoomInvite(ctx, t) }) } } @@ -402,7 +403,7 @@ func (l *channelEventListener) ProcessAsyncUserMessage(message *AsyncMessage) { l.ch <- message } -func RunTestBackendServer_RoomInvite(t *testing.T) { +func RunTestBackendServer_RoomInvite(ctx context.Context, t *testing.T) { require := require.New(t) assert := assert.New(t) _, _, events, hub, _, server := CreateBackendServerForTest(t) @@ -451,16 +452,17 @@ func RunTestBackendServer_RoomInvite(t *testing.T) { } func TestBackendServer_RoomDisinvite(t *testing.T) { - CatchLogForTest(t) for _, backend := range eventBackendsForTest { t.Run(backend, func(t *testing.T) { t.Parallel() - RunTestBackendServer_RoomDisinvite(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + RunTestBackendServer_RoomDisinvite(ctx, t) }) } } -func RunTestBackendServer_RoomDisinvite(t *testing.T) { +func RunTestBackendServer_RoomDisinvite(ctx context.Context, t *testing.T) { require := require.New(t) assert := assert.New(t) _, _, events, hub, _, server := CreateBackendServerForTest(t) @@ -470,7 +472,7 @@ func RunTestBackendServer_RoomDisinvite(t *testing.T) { backend := hub.backend.GetBackend(u) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client, hello := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId) @@ -531,12 +533,13 @@ func RunTestBackendServer_RoomDisinvite(t *testing.T) { func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId) @@ -607,16 +610,17 @@ func TestBackendServer_RoomDisinviteDifferentRooms(t *testing.T) { } func TestBackendServer_RoomUpdate(t *testing.T) { - CatchLogForTest(t) for _, backend := range eventBackendsForTest { t.Run(backend, func(t *testing.T) { t.Parallel() - RunTestBackendServer_RoomUpdate(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + RunTestBackendServer_RoomUpdate(ctx, t) }) } } -func RunTestBackendServer_RoomUpdate(t *testing.T) { +func RunTestBackendServer_RoomUpdate(ctx context.Context, t *testing.T) { require := require.New(t) assert := assert.New(t) _, _, events, hub, _, server := CreateBackendServerForTest(t) @@ -675,16 +679,17 @@ func RunTestBackendServer_RoomUpdate(t *testing.T) { } func TestBackendServer_RoomDelete(t *testing.T) { - CatchLogForTest(t) for _, backend := range eventBackendsForTest { t.Run(backend, func(t *testing.T) { t.Parallel() - RunTestBackendServer_RoomDelete(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + RunTestBackendServer_RoomDelete(ctx, t) }) } } -func RunTestBackendServer_RoomDelete(t *testing.T) { +func RunTestBackendServer_RoomDelete(ctx context.Context, t *testing.T) { require := require.New(t) assert := assert.New(t) _, _, events, hub, _, server := CreateBackendServerForTest(t) @@ -740,10 +745,11 @@ func RunTestBackendServer_RoomDelete(t *testing.T) { } func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) var hub1 *Hub @@ -760,7 +766,7 @@ func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) { _, _, hub1, hub2, server1, server2 = CreateBackendServerWithClusteringForTest(t) } - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client1, hello1 := NewTestClientWithHello(ctx, t, server1, hub1, testDefaultUserId+"1") @@ -837,12 +843,13 @@ func TestBackendServer_ParticipantsUpdatePermissions(t *testing.T) { func TestBackendServer_ParticipantsUpdateEmptyPermissions(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client, hello := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId) @@ -900,12 +907,13 @@ func TestBackendServer_ParticipantsUpdateEmptyPermissions(t *testing.T) { func TestBackendServer_ParticipantsUpdateTimeout(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") @@ -1056,10 +1064,11 @@ func TestBackendServer_ParticipantsUpdateTimeout(t *testing.T) { } func TestBackendServer_InCallAll(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) var hub1 *Hub @@ -1076,7 +1085,7 @@ func TestBackendServer_InCallAll(t *testing.T) { _, _, hub1, hub2, server1, server2 = CreateBackendServerWithClusteringForTest(t) } - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client1, hello1 := NewTestClientWithHello(ctx, t, server1, hub1, testDefaultUserId+"1") @@ -1228,12 +1237,13 @@ func TestBackendServer_InCallAll(t *testing.T) { func TestBackendServer_RoomMessage(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client, _ := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") @@ -1271,7 +1281,6 @@ func TestBackendServer_RoomMessage(t *testing.T) { func TestBackendServer_TurnCredentials(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) _, _, _, _, _, server := CreateBackendServerForTestWithTurn(t) @@ -1301,7 +1310,6 @@ func TestBackendServer_TurnCredentials(t *testing.T) { } func TestBackendServer_StatsAllowedIps(t *testing.T) { - CatchLogForTest(t) config := goconf.NewConfigFile() config.AddOption("app", "trustedproxies", "1.2.3.4") config.AddOption("stats", "allowed_ips", "127.0.0.1, 192.168.0.1, 192.168.1.1/24") @@ -1397,7 +1405,8 @@ func Test_IsNumeric(t *testing.T) { func TestBackendServer_DialoutNoSipBridge(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) @@ -1406,7 +1415,7 @@ func TestBackendServer_DialoutNoSipBridge(t *testing.T) { defer client.CloseWithBye() require.NoError(client.SendHelloInternal()) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() MustSucceed1(t, client.RunUntilHello, ctx) @@ -1440,7 +1449,8 @@ func TestBackendServer_DialoutNoSipBridge(t *testing.T) { func TestBackendServer_DialoutAccepted(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) @@ -1449,7 +1459,7 @@ func TestBackendServer_DialoutAccepted(t *testing.T) { defer client.CloseWithBye() require.NoError(client.SendHelloInternalWithFeatures([]string{"start-dialout"})) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() MustSucceed1(t, client.RunUntilHello, ctx) @@ -1526,7 +1536,8 @@ func TestBackendServer_DialoutAccepted(t *testing.T) { func TestBackendServer_DialoutAcceptedCompat(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) @@ -1535,7 +1546,7 @@ func TestBackendServer_DialoutAcceptedCompat(t *testing.T) { defer client.CloseWithBye() require.NoError(client.SendHelloInternalWithFeatures([]string{"start-dialout"})) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() MustSucceed1(t, client.RunUntilHello, ctx) @@ -1612,7 +1623,8 @@ func TestBackendServer_DialoutAcceptedCompat(t *testing.T) { func TestBackendServer_DialoutRejected(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) @@ -1621,7 +1633,7 @@ func TestBackendServer_DialoutRejected(t *testing.T) { defer client.CloseWithBye() require.NoError(client.SendHelloInternalWithFeatures([]string{"start-dialout"})) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() MustSucceed1(t, client.RunUntilHello, ctx) @@ -1696,7 +1708,8 @@ func TestBackendServer_DialoutRejected(t *testing.T) { func TestBackendServer_DialoutFirstFailed(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) @@ -1709,7 +1722,7 @@ func TestBackendServer_DialoutFirstFailed(t *testing.T) { defer client2.CloseWithBye() require.NoError(client2.SendHelloInternalWithFeatures([]string{"start-dialout"})) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() MustSucceed1(t, client1.RunUntilHello, ctx) diff --git a/backend_storage_etcd.go b/backend_storage_etcd.go index 654e725..769cb0b 100644 --- a/backend_storage_etcd.go +++ b/backend_storage_etcd.go @@ -26,7 +26,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "net/url" "slices" "time" @@ -38,6 +37,7 @@ import ( type backendStorageEtcd struct { backendStorageCommon + logger Logger etcdClient *EtcdClient keyPrefix string keyInfos map[string]*BackendInformationEtcd @@ -50,7 +50,7 @@ type backendStorageEtcd struct { closeFunc context.CancelFunc } -func NewBackendStorageEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient) (BackendStorage, error) { +func NewBackendStorageEtcd(logger Logger, config *goconf.ConfigFile, etcdClient *EtcdClient) (BackendStorage, error) { if etcdClient == nil || !etcdClient.IsConfigured() { return nil, fmt.Errorf("no etcd endpoints configured") } @@ -66,6 +66,7 @@ func NewBackendStorageEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient) (B backendStorageCommon: backendStorageCommon{ backends: make(map[string][]*Backend), }, + logger: logger, etcdClient: etcdClient, keyPrefix: keyPrefix, keyInfos: make(map[string]*BackendInformationEtcd), @@ -120,9 +121,9 @@ func (s *backendStorageEtcd) EtcdClientCreated(client *EtcdClient) { if errors.Is(err, context.Canceled) { return } else if errors.Is(err, context.DeadlineExceeded) { - log.Printf("Timeout getting initial list of backends, retry in %s", backoff.NextWait()) + s.logger.Printf("Timeout getting initial list of backends, retry in %s", backoff.NextWait()) } else { - log.Printf("Could not get initial list of backends, retry in %s: %s", backoff.NextWait(), err) + s.logger.Printf("Could not get initial list of backends, retry in %s: %s", backoff.NextWait(), err) } backoff.Wait(s.closeCtx) @@ -140,7 +141,7 @@ func (s *backendStorageEtcd) EtcdClientCreated(client *EtcdClient) { for s.closeCtx.Err() == nil { var err error if nextRevision, err = client.Watch(s.closeCtx, s.keyPrefix, nextRevision, s, clientv3.WithPrefix()); err != nil { - log.Printf("Error processing watch for %s (%s), retry in %s", s.keyPrefix, err, backoff.NextWait()) + s.logger.Printf("Error processing watch for %s (%s), retry in %s", s.keyPrefix, err, backoff.NextWait()) backoff.Wait(s.closeCtx) continue } @@ -149,7 +150,7 @@ func (s *backendStorageEtcd) EtcdClientCreated(client *EtcdClient) { backoff.Reset() prevRevision = nextRevision } else { - log.Printf("Processing watch for %s interrupted, retry in %s", s.keyPrefix, backoff.NextWait()) + s.logger.Printf("Processing watch for %s interrupted, retry in %s", s.keyPrefix, backoff.NextWait()) backoff.Wait(s.closeCtx) } } @@ -171,11 +172,11 @@ func (s *backendStorageEtcd) getBackends(ctx context.Context, client *EtcdClient func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data []byte, prevValue []byte) { var info BackendInformationEtcd if err := json.Unmarshal(data, &info); err != nil { - log.Printf("Could not decode backend information %s: %s", string(data), err) + s.logger.Printf("Could not decode backend information %s: %s", string(data), err) return } if err := info.CheckValid(); err != nil { - log.Printf("Received invalid backend information %s: %s", string(data), err) + s.logger.Printf("Received invalid backend information %s: %s", string(data), err) return } @@ -205,7 +206,7 @@ func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data entries, found := s.backends[host] if !found { // Simple case, first backend for this host - log.Printf("Added backend %s (from %s)", info.Urls[idx], key) + s.logger.Printf("Added backend %s (from %s)", info.Urls[idx], key) s.backends[host] = []*Backend{backend} added = true continue @@ -215,7 +216,7 @@ func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data replaced := false for idx, entry := range entries { if entry.id == key { - log.Printf("Updated backend %s (from %s)", info.Urls[idx], key) + s.logger.Printf("Updated backend %s (from %s)", info.Urls[idx], key) entries[idx] = backend replaced = true break @@ -224,7 +225,7 @@ func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data if !replaced { // New backend, add to list. - log.Printf("Added backend %s (from %s)", info.Urls[idx], key) + s.logger.Printf("Added backend %s (from %s)", info.Urls[idx], key) s.backends[host] = append(entries, backend) added = true } @@ -256,13 +257,13 @@ func (s *backendStorageEtcd) EtcdKeyDeleted(client *EtcdClient, key string, prev if slices.ContainsFunc(d, func(b *Backend) bool { return slices.Contains(b.urls, u.String()) }) { - log.Printf("Removing backend %s (from %s)", info.Urls[idx], key) + s.logger.Printf("Removing backend %s (from %s)", info.Urls[idx], key) } } continue } - log.Printf("Removing backend %s (from %s)", info.Urls[idx], key) + s.logger.Printf("Removing backend %s (from %s)", info.Urls[idx], key) newEntries := make([]*Backend, 0, len(entries)-1) for _, entry := range entries { if entry.id == key { diff --git a/backend_storage_etcd_test.go b/backend_storage_etcd_test.go index d9bd77c..2c62520 100644 --- a/backend_storage_etcd_test.go +++ b/backend_storage_etcd_test.go @@ -53,7 +53,7 @@ func (tl *testListener) EtcdClientCreated(client *EtcdClient) { } func Test_BackendStorageEtcdNoLeak(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { etcd, client := NewEtcdClientForTest(t) tl := &testListener{ @@ -67,7 +67,7 @@ func Test_BackendStorageEtcdNoLeak(t *testing.T) { config.AddOption("backend", "backendtype", "etcd") config.AddOption("backend", "backendprefix", "/backends") - cfg, err := NewBackendConfiguration(config, client) + cfg, err := NewBackendConfiguration(logger, config, client) require.NoError(t, err) <-tl.closed diff --git a/backend_storage_static.go b/backend_storage_static.go index 601bf4e..dc90dfa 100644 --- a/backend_storage_static.go +++ b/backend_storage_static.go @@ -22,7 +22,6 @@ package signaling import ( - "log" "net/url" "slices" "strings" @@ -35,6 +34,7 @@ import ( type backendStorageStatic struct { backendStorageCommon + logger Logger backendsById map[string]*Backend // Deprecated @@ -43,7 +43,7 @@ type backendStorageStatic struct { compatBackend *Backend } -func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) { +func NewBackendStorageStatic(logger Logger, config *goconf.ConfigFile) (BackendStorage, error) { allowAll, _ := config.GetBool("backend", "allowall") allowHttp, _ := config.GetBool("backend", "allowhttp") commonSecret, _ := GetStringOptionWithEnv(config, "backend", "secret") @@ -56,7 +56,7 @@ func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) var compatBackend *Backend numBackends := 0 if allowAll { - log.Println("WARNING: All backend hostnames are allowed, only use for development!") + logger.Println("WARNING: All backend hostnames are allowed, only use for development!") maxStreamBitrate, err := config.GetInt("backend", "maxstreambitrate") if err != nil || maxStreamBitrate < 0 { maxStreamBitrate = 0 @@ -78,21 +78,21 @@ func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) maxScreenBitrate: api.BandwidthFromBits(uint64(maxScreenBitrate)), } if sessionLimit > 0 { - log.Printf("Allow a maximum of %d sessions", sessionLimit) + logger.Printf("Allow a maximum of %d sessions", sessionLimit) } updateBackendStats(compatBackend) backendsById[compatBackend.id] = compatBackend numBackends++ } else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" { added := make(map[string]*Backend) - for host, configuredBackends := range getConfiguredHosts(backendIds, config, commonSecret) { + for host, configuredBackends := range getConfiguredHosts(logger, backendIds, config, commonSecret) { backends[host] = append(backends[host], configuredBackends...) for _, be := range configuredBackends { added[be.id] = be } } for _, be := range added { - log.Printf("Backend %s added for %s", be.id, strings.Join(be.urls, ", ")) + logger.Printf("Backend %s added for %s", be.id, strings.Join(be.urls, ", ")) backendsById[be.id] = be updateBackendStats(be) be.counted = true @@ -103,7 +103,7 @@ func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) allowMap := make(map[string]bool) for u := range SplitEntries(allowedUrls, ",") { if idx := strings.IndexByte(u, '/'); idx != -1 { - log.Printf("WARNING: Removing path from allowed hostname \"%s\", check your configuration!", u) + logger.Printf("WARNING: Removing path from allowed hostname \"%s\", check your configuration!", u) if u = u[:idx]; u == "" { continue } @@ -113,7 +113,7 @@ func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) } if len(allowMap) == 0 { - log.Println("WARNING: No backend hostnames are allowed, check your configuration!") + logger.Println("WARNING: No backend hostnames are allowed, check your configuration!") } else { maxStreamBitrate, err := config.GetInt("backend", "maxstreambitrate") if err != nil || maxStreamBitrate < 0 { @@ -141,11 +141,11 @@ func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) backends[host] = []*Backend{compatBackend} } if len(hosts) > 1 { - log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.") + logger.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.") } - log.Printf("Allowed backend hostnames: %s", hosts) + logger.Printf("Allowed backend hostnames: %s", hosts) if sessionLimit > 0 { - log.Printf("Allow a maximum of %d sessions", sessionLimit) + logger.Printf("Allow a maximum of %d sessions", sessionLimit) } updateBackendStats(compatBackend) backendsById[compatBackend.id] = compatBackend @@ -154,7 +154,7 @@ func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) } if numBackends == 0 { - log.Printf("WARNING: No backends configured, client connections will not be possible.") + logger.Printf("WARNING: No backends configured, client connections will not be possible.") } statsBackendsCurrent.Add(float64(numBackends)) @@ -163,6 +163,7 @@ func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) backends: backends, }, + logger: logger, backendsById: backendsById, allowAll: allowAll, @@ -187,7 +188,7 @@ func (s *backendStorageStatic) RemoveBackendsForHost(host string, seen map[strin urls := slices.DeleteFunc(backend.urls, func(s string) bool { return !strings.Contains(s, "://"+host) }) - log.Printf("Backend %s removed for %s", backend.id, strings.Join(urls, ", ")) + s.logger.Printf("Backend %s removed for %s", backend.id, strings.Join(urls, ", ")) if len(urls) == len(backend.urls) && backend.counted { deleteBackendStats(backend) delete(s.backendsById, backend.Id()) @@ -225,7 +226,7 @@ func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend, seen backends = slices.Delete(backends, index, index+1) if seen[newBackend.id] != seenUpdated { seen[newBackend.id] = seenUpdated - log.Printf("Backend %s updated for %s", newBackend.id, strings.Join(newBackend.urls, ", ")) + s.logger.Printf("Backend %s updated for %s", newBackend.id, strings.Join(newBackend.urls, ", ")) updateBackendStats(newBackend) newBackend.counted = existingBackend.counted s.backendsById[newBackend.id] = newBackend @@ -242,7 +243,7 @@ func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend, seen urls := slices.DeleteFunc(removed.urls, func(s string) bool { return !strings.Contains(s, "://"+host) }) - log.Printf("Backend %s removed for %s", removed.id, strings.Join(urls, ", ")) + s.logger.Printf("Backend %s removed for %s", removed.id, strings.Join(urls, ", ")) if len(urls) == len(removed.urls) && removed.counted { deleteBackendStats(removed) delete(s.backendsById, removed.Id()) @@ -268,7 +269,7 @@ func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend, seen s.backendsById[added.id] = added } - log.Printf("Backend %s added for %s", added.id, strings.Join(added.urls, ", ")) + s.logger.Printf("Backend %s added for %s", added.id, strings.Join(added.urls, ", ")) if !added.counted { updateBackendStats(added) addedBackends++ @@ -293,17 +294,17 @@ func getConfiguredBackendIDs(backendIds string) (ids []string) { return ids } -func getConfiguredHosts(backendIds string, config *goconf.ConfigFile, commonSecret string) (hosts map[string][]*Backend) { +func getConfiguredHosts(logger Logger, backendIds string, config *goconf.ConfigFile, commonSecret string) (hosts map[string][]*Backend) { hosts = make(map[string][]*Backend) seenUrls := make(map[string]string) for _, id := range getConfiguredBackendIDs(backendIds) { secret, _ := GetStringOptionWithEnv(config, id, "secret") if secret == "" && commonSecret != "" { - log.Printf("Backend %s has no own shared secret set, using common shared secret", id) + logger.Printf("Backend %s has no own shared secret set, using common shared secret", id) secret = commonSecret } if secret == "" { - log.Printf("Backend %s is missing or incomplete, skipping", id) + logger.Printf("Backend %s is missing or incomplete, skipping", id) continue } @@ -312,7 +313,7 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile, commonSecr sessionLimit = 0 } if sessionLimit > 0 { - log.Printf("Backend %s allows a maximum of %d sessions", id, sessionLimit) + logger.Printf("Backend %s allows a maximum of %d sessions", id, sessionLimit) } maxStreamBitrate, err := config.GetInt(id, "maxstreambitrate") @@ -335,7 +336,7 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile, commonSecr } if len(urls) == 0 { - log.Printf("Backend %s is missing or incomplete, skipping", id) + logger.Printf("Backend %s is missing or incomplete, skipping", id) continue } @@ -357,7 +358,7 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile, commonSecr parsed, err := url.Parse(u) if err != nil { - log.Printf("Backend %s has an invalid url %s configured (%s), skipping", id, u, err) + logger.Printf("Backend %s has an invalid url %s configured (%s), skipping", id, u, err) continue } @@ -367,7 +368,7 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile, commonSecr } if prev, found := seenUrls[u]; found { - log.Printf("Url %s in backend %s was already used in backend %s, skipping", u, id, prev) + logger.Printf("Url %s in backend %s was already used in backend %s, skipping", u, id, prev) continue } @@ -392,14 +393,14 @@ func (s *backendStorageStatic) Reload(config *goconf.ConfigFile) { defer s.mu.Unlock() if s.compatBackend != nil { - log.Println("Old-style configuration active, reload is not supported") + s.logger.Println("Old-style configuration active, reload is not supported") return } commonSecret, _ := GetStringOptionWithEnv(config, "backend", "secret") if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" { - configuredHosts := getConfiguredHosts(backendIds, config, commonSecret) + configuredHosts := getConfiguredHosts(s.logger, backendIds, config, commonSecret) // remove backends that are no longer configured seen := make(map[string]seenState) diff --git a/capabilities.go b/capabilities.go index c5bb1e0..2f68bb9 100644 --- a/capabilities.go +++ b/capabilities.go @@ -25,7 +25,6 @@ import ( "context" "encoding/json" "errors" - "log" "net/http" "net/url" "strings" @@ -118,6 +117,7 @@ func (e *capabilitiesEntry) errorIfMustRevalidate(err error) (bool, error) { } func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Time) (bool, error) { + logger := LoggerFromContext(ctx) e.mu.Lock() defer e.mu.Unlock() @@ -136,18 +136,18 @@ func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Tim capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities" } - log.Printf("Capabilities expired for %s, updating", capUrl.String()) + logger.Printf("Capabilities expired for %s, updating", capUrl.String()) client, pool, err := e.c.pool.Get(ctx, &capUrl) if err != nil { - log.Printf("Could not get client for host %s: %s", capUrl.Host, err) + logger.Printf("Could not get client for host %s: %s", capUrl.Host, err) return false, err } defer pool.Put(client) req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil) if err != nil { - log.Printf("Could not create request to %s: %s", &capUrl, err) + logger.Printf("Could not create request to %s: %s", &capUrl, err) return false, err } req.Header.Set("Accept", "application/json") @@ -179,22 +179,22 @@ func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Tim e.nextUpdate = now.Add(maxAge) if response.StatusCode == http.StatusNotModified { - log.Printf("Capabilities %+v from %s have not changed", e.capabilities, url) + logger.Printf("Capabilities %+v from %s have not changed", e.capabilities, url) return false, nil } else if response.StatusCode != http.StatusOK { - log.Printf("Received unexpected HTTP status from %s: %s", url, response.Status) + logger.Printf("Received unexpected HTTP status from %s: %s", url, response.Status) return e.errorIfMustRevalidate(ErrUnexpectedHttpStatus) } ct := response.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/json") { - log.Printf("Received unsupported content-type from %s: %s (%s)", url, ct, response.Status) + logger.Printf("Received unsupported content-type from %s: %s (%s)", url, ct, response.Status) return e.errorIfMustRevalidate(ErrUnsupportedContentType) } body, err := e.c.buffers.ReadAll(response.Body) if err != nil { - log.Printf("Could not read response body from %s: %s", url, err) + logger.Printf("Could not read response body from %s: %s", url, err) return e.errorIfMustRevalidate(err) } @@ -202,34 +202,34 @@ func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Tim var ocs OcsResponse if err := json.Unmarshal(body.Bytes(), &ocs); err != nil { - log.Printf("Could not decode OCS response %s from %s: %s", body.String(), url, err) + logger.Printf("Could not decode OCS response %s from %s: %s", body.String(), url, err) return e.errorIfMustRevalidate(err) } else if ocs.Ocs == nil || len(ocs.Ocs.Data) == 0 { - log.Printf("Incomplete OCS response %s from %s", body.String(), url) + logger.Printf("Incomplete OCS response %s from %s", body.String(), url) return e.errorIfMustRevalidate(ErrIncompleteResponse) } var capaResponse CapabilitiesResponse if err := json.Unmarshal(ocs.Ocs.Data, &capaResponse); err != nil { - log.Printf("Could not decode OCS response body %s from %s: %s", string(ocs.Ocs.Data), url, err) + logger.Printf("Could not decode OCS response body %s from %s: %s", string(ocs.Ocs.Data), url, err) return e.errorIfMustRevalidate(err) } capaObj, found := capaResponse.Capabilities[AppNameSpreed] if !found || len(capaObj) == 0 { - log.Printf("No capabilities received for app spreed from %s: %+v", url, capaResponse) + logger.Printf("No capabilities received for app spreed from %s: %+v", url, capaResponse) e.capabilities = nil return false, nil } var capa api.StringMap if err := json.Unmarshal(capaObj, &capa); err != nil { - log.Printf("Unsupported capabilities received for app spreed from %s: %+v", url, capaResponse) + logger.Printf("Unsupported capabilities received for app spreed from %s: %+v", url, capaResponse) e.capabilities = nil return false, nil } - log.Printf("Received capabilities %+v from %s", capa, url) + logger.Printf("Received capabilities %+v from %s", capa, url) e.capabilities = capa return true, nil } @@ -351,9 +351,10 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (api.St } func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool { + logger := LoggerFromContext(ctx) caps, _, err := c.loadCapabilities(ctx, u) if err != nil { - log.Printf("Could not get capabilities for %s: %s", u, err) + logger.Printf("Could not get capabilities for %s: %s", u, err) return false } @@ -364,7 +365,7 @@ func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, fea features, ok := featuresInterface.([]any) if !ok { - log.Printf("Invalid features list received for %s: %+v", u, featuresInterface) + logger.Printf("Invalid features list received for %s: %+v", u, featuresInterface) return false } @@ -377,9 +378,10 @@ func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, fea } func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (api.StringMap, bool, bool) { + logger := LoggerFromContext(ctx) caps, cached, err := c.loadCapabilities(ctx, u) if err != nil { - log.Printf("Could not get capabilities for %s: %s", u, err) + logger.Printf("Could not get capabilities for %s: %s", u, err) return nil, cached, false } @@ -390,7 +392,7 @@ func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group str config, ok := api.ConvertStringMap(configInterface) if !ok { - log.Printf("Invalid config mapping received from %s: %+v", u, configInterface) + logger.Printf("Invalid config mapping received from %s: %+v", u, configInterface) return nil, cached, false } @@ -401,7 +403,7 @@ func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group str groupConfig, ok := api.ConvertStringMap(groupInterface) if !ok { - log.Printf("Invalid group mapping \"%s\" received from %s: %+v", group, u, groupInterface) + logger.Printf("Invalid group mapping \"%s\" received from %s: %+v", group, u, groupInterface) return nil, cached, false } @@ -427,7 +429,8 @@ func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, case float64: return int(value), cached, true default: - log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value) + logger := LoggerFromContext(ctx) + logger.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value) } return 0, cached, false @@ -448,7 +451,8 @@ func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, k case string: return value, cached, true default: - log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value) + logger := LoggerFromContext(ctx) + logger.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value) } return "", cached, false diff --git a/capabilities_test.go b/capabilities_test.go index 72a7f00..c4e7a7b 100644 --- a/capabilities_test.go +++ b/capabilities_test.go @@ -174,11 +174,12 @@ func SetCapabilitiesGetNow(t *testing.T, capabilities *Capabilities, f func() ti func TestCapabilities(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) url, capabilities := NewCapabilitiesForTest(t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() assert.True(capabilities.HasCapabilityFeature(ctx, url, "foo")) @@ -217,7 +218,8 @@ func TestCapabilities(t *testing.T) { func TestInvalidateCapabilities(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { @@ -225,7 +227,7 @@ func TestInvalidateCapabilities(t *testing.T) { return nil }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() expectedString := "bar" @@ -277,7 +279,8 @@ func TestInvalidateCapabilities(t *testing.T) { func TestCapabilitiesNoCache(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { @@ -285,7 +288,7 @@ func TestCapabilitiesNoCache(t *testing.T) { return nil }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() expectedString := "bar" @@ -321,7 +324,8 @@ func TestCapabilitiesNoCache(t *testing.T) { func TestCapabilitiesShortCache(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { @@ -329,7 +333,7 @@ func TestCapabilitiesShortCache(t *testing.T) { return nil }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() expectedString := "bar" @@ -375,7 +379,8 @@ func TestCapabilitiesShortCache(t *testing.T) { func TestCapabilitiesNoCacheETag(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { @@ -389,7 +394,7 @@ func TestCapabilitiesNoCacheETag(t *testing.T) { return nil }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() expectedString := "bar" @@ -416,7 +421,8 @@ func TestCapabilitiesNoCacheETag(t *testing.T) { func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { @@ -427,7 +433,7 @@ func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) { return nil }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() expectedString := "bar" @@ -456,7 +462,8 @@ func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) { func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { @@ -467,7 +474,7 @@ func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) { return nil }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() expectedString := "bar" @@ -496,7 +503,8 @@ func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) { func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { @@ -507,7 +515,7 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) { return nil }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() expectedString := "bar" @@ -534,7 +542,8 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) { func TestConcurrentExpired(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error { @@ -542,7 +551,7 @@ func TestConcurrentExpired(t *testing.T) { return nil }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() expectedString := "bar" diff --git a/certificate_reloader.go b/certificate_reloader.go index 3e23c96..373b503 100644 --- a/certificate_reloader.go +++ b/certificate_reloader.go @@ -25,12 +25,13 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "log" "os" "sync/atomic" ) type CertificateReloader struct { + logger Logger + certFile string certWatcher *FileWatcher @@ -42,22 +43,23 @@ type CertificateReloader struct { reloadCounter atomic.Uint64 } -func NewCertificateReloader(certFile string, keyFile string) (*CertificateReloader, error) { +func NewCertificateReloader(logger Logger, certFile string, keyFile string) (*CertificateReloader, error) { pair, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, fmt.Errorf("could not load certificate / key: %w", err) } reloader := &CertificateReloader{ + logger: logger, certFile: certFile, keyFile: keyFile, } reloader.certificate.Store(&pair) - reloader.certWatcher, err = NewFileWatcher(certFile, reloader.reload) + reloader.certWatcher, err = NewFileWatcher(reloader.logger, certFile, reloader.reload) if err != nil { return nil, err } - reloader.keyWatcher, err = NewFileWatcher(keyFile, reloader.reload) + reloader.keyWatcher, err = NewFileWatcher(reloader.logger, keyFile, reloader.reload) if err != nil { reloader.certWatcher.Close() // nolint return nil, err @@ -72,10 +74,10 @@ func (r *CertificateReloader) Close() { } func (r *CertificateReloader) reload(filename string) { - log.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile) + r.logger.Printf("reloading certificate from %s with %s", r.certFile, r.keyFile) pair, err := tls.LoadX509KeyPair(r.certFile, r.keyFile) if err != nil { - log.Printf("could not load certificate / key: %s", err) + r.logger.Printf("could not load certificate / key: %s", err) return } @@ -100,6 +102,8 @@ func (r *CertificateReloader) GetReloadCounter() uint64 { } type CertPoolReloader struct { + logger Logger + certFile string certWatcher *FileWatcher @@ -122,17 +126,18 @@ func loadCertPool(filename string) (*x509.CertPool, error) { return pool, nil } -func NewCertPoolReloader(certFile string) (*CertPoolReloader, error) { +func NewCertPoolReloader(logger Logger, certFile string) (*CertPoolReloader, error) { pool, err := loadCertPool(certFile) if err != nil { return nil, err } reloader := &CertPoolReloader{ + logger: logger, certFile: certFile, } reloader.pool.Store(pool) - reloader.certWatcher, err = NewFileWatcher(certFile, reloader.reload) + reloader.certWatcher, err = NewFileWatcher(reloader.logger, certFile, reloader.reload) if err != nil { return nil, err } @@ -145,10 +150,10 @@ func (r *CertPoolReloader) Close() { } func (r *CertPoolReloader) reload(filename string) { - log.Printf("reloading certificate pool from %s", r.certFile) + r.logger.Printf("reloading certificate pool from %s", r.certFile) pool, err := loadCertPool(r.certFile) if err != nil { - log.Printf("could not load certificate pool: %s", err) + r.logger.Printf("could not load certificate pool: %s", err) return } diff --git a/client.go b/client.go index b7d7d38..7ffff52 100644 --- a/client.go +++ b/client.go @@ -26,7 +26,6 @@ import ( "context" "encoding/json" "errors" - "log" "net" "strconv" "strings" @@ -121,6 +120,7 @@ type ClientGeoIpHandler interface { } type Client struct { + logger Logger ctx context.Context conn *websocket.Conn addr string @@ -163,6 +163,7 @@ func NewClient(ctx context.Context, conn *websocket.Conn, remoteAddress string, } func (c *Client) SetConn(ctx context.Context, conn *websocket.Conn, remoteAddress string, handler ClientHandler) { + c.logger = LoggerFromContext(ctx) c.ctx = ctx c.conn = conn c.addr = remoteAddress @@ -332,7 +333,7 @@ func (c *Client) ReadPump() { conn := c.conn c.mu.Unlock() if conn == nil { - log.Printf("Connection from %s closed while starting readPump", addr) + c.logger.Printf("Connection from %s closed while starting readPump", addr) return } @@ -348,9 +349,9 @@ func (c *Client) ReadPump() { if c.logRTT { rtt_ms := rtt.Nanoseconds() / time.Millisecond.Nanoseconds() if sessionId := c.GetSessionId(); sessionId != "" { - log.Printf("Client %s has RTT of %d ms (%s)", sessionId, rtt_ms, rtt) + c.logger.Printf("Client %s has RTT of %d ms (%s)", sessionId, rtt_ms, rtt) } else { - log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) + c.logger.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) } } statsClientRTT.Observe(float64(rtt.Milliseconds())) @@ -371,9 +372,9 @@ func (c *Client) ReadPump() { websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { if sessionId := c.GetSessionId(); sessionId != "" { - log.Printf("Error reading from client %s: %v", sessionId, err) + c.logger.Printf("Error reading from client %s: %v", sessionId, err) } else { - log.Printf("Error reading from %s: %v", addr, err) + c.logger.Printf("Error reading from %s: %v", addr, err) } } break @@ -381,9 +382,9 @@ func (c *Client) ReadPump() { if messageType != websocket.TextMessage { if sessionId := c.GetSessionId(); sessionId != "" { - log.Printf("Unsupported message type %v from client %s", messageType, sessionId) + c.logger.Printf("Unsupported message type %v from client %s", messageType, sessionId) } else { - log.Printf("Unsupported message type %v from %s", messageType, addr) + c.logger.Printf("Unsupported message type %v from %s", messageType, addr) } c.SendError(InvalidFormat) continue @@ -392,9 +393,9 @@ func (c *Client) ReadPump() { decodeBuffer, err := bufferPool.ReadAll(reader) if err != nil { if sessionId := c.GetSessionId(); sessionId != "" { - log.Printf("Error reading message from client %s: %v", sessionId, err) + c.logger.Printf("Error reading message from client %s: %v", sessionId, err) } else { - log.Printf("Error reading message from %s: %v", addr, err) + c.logger.Printf("Error reading message from %s: %v", addr, err) } break } @@ -446,9 +447,9 @@ func (c *Client) writeInternal(message json.Marshaler) bool { } if sessionId := c.GetSessionId(); sessionId != "" { - log.Printf("Could not send message %+v to client %s: %v", message, sessionId, err) + c.logger.Printf("Could not send message %+v to client %s: %v", message, sessionId, err) } else { - log.Printf("Could not send message %+v to %s: %v", message, c.RemoteAddr(), err) + c.logger.Printf("Could not send message %+v to %s: %v", message, c.RemoteAddr(), err) } closeData = websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") goto close @@ -459,9 +460,9 @@ close: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil { if sessionId := c.GetSessionId(); sessionId != "" { - log.Printf("Could not send close message to client %s: %v", sessionId, err) + c.logger.Printf("Could not send close message to client %s: %v", sessionId, err) } else { - log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err) + c.logger.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err) } } return false @@ -486,9 +487,9 @@ func (c *Client) writeError(e error) bool { // nolint c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil { if sessionId := c.GetSessionId(); sessionId != "" { - log.Printf("Could not send close message to client %s: %v", sessionId, err) + c.logger.Printf("Could not send close message to client %s: %v", sessionId, err) } else { - log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err) + c.logger.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err) } } return false @@ -534,9 +535,9 @@ func (c *Client) sendPing() bool { c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil { if sessionId := c.GetSessionId(); sessionId != "" { - log.Printf("Could not send ping to client %s: %v", sessionId, err) + c.logger.Printf("Could not send ping to client %s: %v", sessionId, err) } else { - log.Printf("Could not send ping to %s: %v", c.RemoteAddr(), err) + c.logger.Printf("Could not send ping to %s: %v", c.RemoteAddr(), err) } return false } diff --git a/clientsession.go b/clientsession.go index 832bf15..a22af47 100644 --- a/clientsession.go +++ b/clientsession.go @@ -25,7 +25,6 @@ import ( "context" "encoding/json" "fmt" - "log" "maps" "net/url" "slices" @@ -55,6 +54,7 @@ const ( type ResponseHandlerFunc func(message *ClientMessage) bool type ClientSession struct { + logger Logger hub *Hub events AsyncEvents privateId PrivateSessionId @@ -119,8 +119,10 @@ type ClientSession struct { } func NewClientSession(hub *Hub, privateId PrivateSessionId, publicId PublicSessionId, data *SessionIdData, backend *Backend, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) { - ctx, closeFunc := context.WithCancel(context.Background()) + ctx := NewLoggerContext(context.Background(), hub.logger) + ctx, closeFunc := context.WithCancel(ctx) s := &ClientSession{ + logger: hub.logger, hub: hub, events: hub.events, privateId: privateId, @@ -276,7 +278,7 @@ func (s *ClientSession) SetPermissions(permissions []Permission) { s.permissions = p s.supportsPermissions = true - log.Printf("Permissions of session %s changed: %s", s.PublicId(), permissions) + s.logger.Printf("Permissions of session %s changed: %s", s.PublicId(), permissions) } func (s *ClientSession) Backend() *Backend { @@ -443,19 +445,19 @@ func (s *ClientSession) UpdateRoomSessionId(roomSessionId RoomSessionId) error { if roomSessionId != "" { if room := s.GetRoom(); room != nil { - log.Printf("Session %s updated room session id to %s in room %s", s.PublicId(), roomSessionId, room.Id()) + s.logger.Printf("Session %s updated room session id to %s in room %s", s.PublicId(), roomSessionId, room.Id()) } else if client := s.GetFederationClient(); client != nil { - log.Printf("Session %s updated room session id to %s in federated room %s", s.PublicId(), roomSessionId, client.RemoteRoomId()) + s.logger.Printf("Session %s updated room session id to %s in federated room %s", s.PublicId(), roomSessionId, client.RemoteRoomId()) } else { - log.Printf("Session %s updated room session id to %s in unknown room", s.PublicId(), roomSessionId) + s.logger.Printf("Session %s updated room session id to %s in unknown room", s.PublicId(), roomSessionId) } } else { if room := s.GetRoom(); room != nil { - log.Printf("Session %s cleared room session id in room %s", s.PublicId(), room.Id()) + s.logger.Printf("Session %s cleared room session id in room %s", s.PublicId(), room.Id()) } else if client := s.GetFederationClient(); client != nil { - log.Printf("Session %s cleared room session id in federated room %s", s.PublicId(), client.RemoteRoomId()) + s.logger.Printf("Session %s cleared room session id in federated room %s", s.PublicId(), client.RemoteRoomId()) } else { - log.Printf("Session %s cleared room session id in unknown room", s.PublicId()) + s.logger.Printf("Session %s cleared room session id in unknown room", s.PublicId()) } } @@ -477,7 +479,7 @@ func (s *ClientSession) SubscribeRoomEvents(roomid string, roomSessionId RoomSes return err } } - log.Printf("Session %s joined room %s with room session id %s", s.PublicId(), roomid, roomSessionId) + s.logger.Printf("Session %s joined room %s with room session id %s", s.PublicId(), roomid, roomSessionId) s.roomSessionId = roomSessionId return nil } @@ -491,7 +493,7 @@ func (s *ClientSession) LeaveCall() { return } - log.Printf("Session %s left call %s", s.PublicId(), room.Id()) + s.logger.Printf("Session %s left call %s", s.PublicId(), room.Id()) s.releaseMcuObjects() } @@ -503,7 +505,7 @@ func (s *ClientSession) LeaveRoomWithMessage(notify bool, message *ClientMessage if prev := s.federation.Swap(nil); prev != nil { // Session was connected to a federation room. if err := prev.Leave(message); err != nil { - log.Printf("Error leaving room for session %s on federation client %s: %s", s.PublicId(), prev.URL(), err) + s.logger.Printf("Error leaving room for session %s on federation client %s: %s", s.PublicId(), prev.URL(), err) prev.Close() } return nil @@ -548,15 +550,15 @@ func (s *ClientSession) doUnsubscribeRoomEvents(notify bool) { if notify && room != nil && s.roomSessionId != "" && !s.roomSessionId.IsFederated() { // Notify go func(sid RoomSessionId) { - ctx := context.Background() + ctx := NewLoggerContext(context.Background(), s.logger) request := NewBackendClientRoomRequest(room.Id(), s.userId, sid) request.Room.UpdateFromSession(s) request.Room.Action = "leave" var response api.StringMap if err := s.hub.backend.PerformJSONRequest(ctx, s.ParsedBackendOcsUrl(), request, &response); err != nil { - log.Printf("Could not notify about room session %s left room %s: %s", sid, room.Id(), err) + s.logger.Printf("Could not notify about room session %s left room %s: %s", sid, room.Id(), err) } else { - log.Printf("Removed room session %s: %+v", sid, response) + s.logger.Printf("Removed room session %s: %+v", sid, response) } }(s.roomSessionId) } @@ -575,7 +577,7 @@ func (s *ClientSession) clearClientLocked(client HandlerClient) { if s.client == nil { return } else if client != nil && s.client != client { - log.Printf("Trying to clear other client in session %s", s.PublicId()) + s.logger.Printf("Trying to clear other client in session %s", s.PublicId()) return } @@ -630,7 +632,7 @@ func (s *ClientSession) sendOffer(client McuClient, sender PublicSessionId, stre } offer_data, err := json.Marshal(offer_message) if err != nil { - log.Println("Could not serialize offer", offer_message, err) + s.logger.Println("Could not serialize offer", offer_message, err) return } response_message := &ServerMessage{ @@ -661,7 +663,7 @@ func (s *ClientSession) sendCandidate(client McuClient, sender PublicSessionId, } candidate_data, err := json.Marshal(candidate_message) if err != nil { - log.Println("Could not serialize candidate", candidate_message, err) + s.logger.Println("Could not serialize candidate", candidate_message, err) return } response_message := &ServerMessage{ @@ -750,7 +752,7 @@ func (s *ClientSession) OnIceCandidate(client McuClient, candidate any) { } } - log.Printf("Session %s received candidate %+v for unknown client %s", s.PublicId(), candidate, client.Id()) + s.logger.Printf("Session %s received candidate %+v for unknown client %s", s.PublicId(), candidate, client.Id()) } func (s *ClientSession) OnIceCompleted(client McuClient) { @@ -942,7 +944,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea } else { s.publishers[streamType] = publisher } - log.Printf("Publishing %s as %s for session %s", streamType, publisher.Id(), s.PublicId()) + s.logger.Printf("Publishing %s as %s for session %s", streamType, publisher.Id(), s.PublicId()) s.publisherWaiters.Wakeup() } else { publisher.SetMedia(mediaTypes) @@ -1021,7 +1023,7 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id P } else { s.subscribers[getStreamId(id, streamType)] = subscriber } - log.Printf("Subscribing %s from %s as %s in session %s", streamType, id, subscriber.Id(), s.PublicId()) + s.logger.Printf("Subscribing %s from %s as %s in session %s", streamType, id, subscriber.Id(), s.PublicId()) } return subscriber, nil @@ -1059,7 +1061,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { if (publisher.HasMedia(MediaTypeAudio) && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_AUDIO)) || (publisher.HasMedia(MediaTypeVideo) && !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_VIDEO)) { delete(s.publishers, StreamTypeVideo) - log.Printf("Session %s is no longer allowed to publish media, closing publisher %s", s.PublicId(), publisher.Id()) + s.logger.Printf("Session %s is no longer allowed to publish media, closing publisher %s", s.PublicId(), publisher.Id()) go func() { publisher.Close(context.Background()) }() @@ -1070,7 +1072,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) { if publisher, found := s.publishers[StreamTypeScreen]; found { delete(s.publishers, StreamTypeScreen) - log.Printf("Session %s is no longer allowed to publish screen, closing publisher %s", s.PublicId(), publisher.Id()) + s.logger.Printf("Session %s is no longer allowed to publish screen, closing publisher %s", s.PublicId(), publisher.Id()) go func() { publisher.Close(context.Background()) }() @@ -1081,7 +1083,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { return case "message": if message.Message.Type == "bye" && message.Message.Bye.Reason == "room_session_reconnected" { - log.Printf("Closing session %s because same room session %s connected", s.PublicId(), s.RoomSessionId()) + s.logger.Printf("Closing session %s because same room session %s connected", s.PublicId(), s.RoomSessionId()) s.LeaveRoom(false) defer s.closeAndWait(false) } @@ -1093,7 +1095,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { mc, err := s.GetOrCreateSubscriber(ctx, s.hub.mcu, message.SendOffer.SessionId, StreamType(message.SendOffer.Data.RoomType)) if err != nil { - log.Printf("Could not create MCU subscriber for session %s to process sendoffer in %s: %s", message.SendOffer.SessionId, s.PublicId(), err) + s.logger.Printf("Could not create MCU subscriber for session %s to process sendoffer in %s: %s", message.SendOffer.SessionId, s.PublicId(), err) if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{ Type: "message", Message: &ServerMessage{ @@ -1102,11 +1104,11 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { Error: NewError("client_not_found", "No MCU client found to send message to."), }, }); err != nil { - log.Printf("Error sending sendoffer error response to %s: %s", message.SendOffer.SessionId, err) + s.logger.Printf("Error sending sendoffer error response to %s: %s", message.SendOffer.SessionId, err) } return } else if mc == nil { - log.Printf("No MCU subscriber found for session %s to process sendoffer in %s", message.SendOffer.SessionId, s.PublicId()) + s.logger.Printf("No MCU subscriber found for session %s to process sendoffer in %s", message.SendOffer.SessionId, s.PublicId()) if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{ Type: "message", Message: &ServerMessage{ @@ -1115,14 +1117,14 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { Error: NewError("client_not_found", "No MCU client found to send message to."), }, }); err != nil { - log.Printf("Error sending sendoffer error response to %s: %s", message.SendOffer.SessionId, err) + s.logger.Printf("Error sending sendoffer error response to %s: %s", message.SendOffer.SessionId, err) } return } mc.SendMessage(s.Context(), nil, message.SendOffer.Data, func(err error, response api.StringMap) { if err != nil { - log.Printf("Could not send MCU message %+v for session %s to %s: %s", message.SendOffer.Data, message.SendOffer.SessionId, s.PublicId(), err) + s.logger.Printf("Could not send MCU message %+v for session %s to %s: %s", message.SendOffer.Data, message.SendOffer.SessionId, s.PublicId(), err) if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{ Type: "message", Message: &ServerMessage{ @@ -1131,7 +1133,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { Error: NewError("processing_failed", "Processing of the message failed, please check server logs."), }, }); err != nil { - log.Printf("Error sending sendoffer error response to %s: %s", message.SendOffer.SessionId, err) + s.logger.Printf("Error sending sendoffer error response to %s: %s", message.SendOffer.SessionId, err) } return } else if response == nil { @@ -1172,7 +1174,7 @@ func (s *ClientSession) storePendingMessage(message *ServerMessage) { } s.pendingClientMessages = append(s.pendingClientMessages, message) if len(s.pendingClientMessages) >= warnPendingMessagesCount { - log.Printf("Session %s has %d pending messages", s.PublicId(), len(s.pendingClientMessages)) + s.logger.Printf("Session %s has %d pending messages", s.PublicId(), len(s.pendingClientMessages)) } } @@ -1227,7 +1229,7 @@ func (s *ClientSession) filterDuplicateJoin(entries []*EventServerMessageSession result := make([]*EventServerMessageSessionEntry, 0, len(entries)) for _, e := range entries { if s.seenJoinedEvents[e.SessionId] { - log.Printf("Session %s got duplicate joined event for %s, ignoring", s.publicId, e.SessionId) + s.logger.Printf("Session %s got duplicate joined event for %s, ignoring", s.publicId, e.SessionId) continue } @@ -1383,7 +1385,7 @@ func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage { switch msg.Type { case "message": if msg.Message == nil { - log.Printf("Received asynchronous message without payload: %+v", msg) + s.logger.Printf("Received asynchronous message without payload: %+v", msg) return nil } @@ -1423,7 +1425,7 @@ func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage { // Can happen mostly during tests where an older room async message // could be received by a subscriber that joined after it was sent. if joined := s.getRoomJoinTime(); joined.IsZero() || msg.SendTime.Before(joined) { - log.Printf("Message %+v was sent on %s before room was joined on %s, ignoring", msg.Message, msg.SendTime, joined) + s.logger.Printf("Message %+v was sent on %s before room was joined on %s, ignoring", msg.Message, msg.SendTime, joined) return nil } } @@ -1431,7 +1433,7 @@ func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage { return msg.Message default: - log.Printf("Received async message with unsupported type %s: %+v", msg.Type, msg) + s.logger.Printf("Received async message with unsupported type %s: %+v", msg.Type, msg) return nil } } @@ -1453,7 +1455,7 @@ func (s *ClientSession) NotifySessionResumed(client HandlerClient) { s.hasPendingParticipantsUpdate = false s.mu.Unlock() - log.Printf("Send %d pending messages to session %s", len(messages), s.PublicId()) + s.logger.Printf("Send %d pending messages to session %s", len(messages), s.PublicId()) // Send through session to handle connection interruptions. s.SendMessages(messages) diff --git a/clientsession_test.go b/clientsession_test.go index a9e2c72..b0f0cb4 100644 --- a/clientsession_test.go +++ b/clientsession_test.go @@ -37,7 +37,6 @@ import ( func TestBandwidth_Client(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -45,8 +44,7 @@ func TestBandwidth_Client(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - mcu, err := NewTestMCU() - require.NoError(err) + mcu := NewTestMCU(t) require.NoError(mcu.Start(ctx)) defer mcu.Stop() @@ -86,7 +84,6 @@ func TestBandwidth_Client(t *testing.T) { func TestBandwidth_Backend(t *testing.T) { t.Parallel() - CatchLogForTest(t) hub, _, _, server := CreateHubWithMultipleBackendsForTest(t) u, err := url.Parse(server.URL + "/one") @@ -100,8 +97,7 @@ func TestBandwidth_Backend(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - mcu, err := NewTestMCU() - require.NoError(t, err) + mcu := NewTestMCU(t) require.NoError(t, mcu.Start(ctx)) defer mcu.Stop() @@ -167,7 +163,6 @@ func TestBandwidth_Backend(t *testing.T) { func TestFeatureChatRelay(t *testing.T) { t.Parallel() - CatchLogForTest(t) testFunc := func(feature bool) func(t *testing.T) { return func(t *testing.T) { @@ -253,11 +248,8 @@ func TestFeatureChatRelay(t *testing.T) { } func TestFeatureChatRelayFederation(t *testing.T) { - CatchLogForTest(t) - var testFunc = func(feature bool) func(t *testing.T) { return func(t *testing.T) { - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) @@ -466,7 +458,6 @@ func TestFeatureChatRelayFederation(t *testing.T) { func TestPermissionHideDisplayNames(t *testing.T) { t.Parallel() - CatchLogForTest(t) testFunc := func(permission bool) func(t *testing.T) { return func(t *testing.T) { diff --git a/deferred_executor.go b/deferred_executor.go index c193eee..3162058 100644 --- a/deferred_executor.go +++ b/deferred_executor.go @@ -22,7 +22,6 @@ package signaling import ( - "log" "reflect" "runtime" "runtime/debug" @@ -32,16 +31,18 @@ import ( // DeferredExecutor will asynchronously execute functions while maintaining // their order. type DeferredExecutor struct { + logger Logger queue chan func() closed chan struct{} closeOnce sync.Once } -func NewDeferredExecutor(queueSize int) *DeferredExecutor { +func NewDeferredExecutor(logger Logger, queueSize int) *DeferredExecutor { if queueSize < 0 { queueSize = 0 } result := &DeferredExecutor{ + logger: logger, queue: make(chan func(), queueSize), closed: make(chan struct{}), } @@ -68,9 +69,9 @@ func getFunctionName(i any) string { func (e *DeferredExecutor) Execute(f func()) { defer func() { - if e := recover(); e != nil { - log.Printf("Could not defer function %v: %+v", getFunctionName(f), e) - log.Printf("Called from %s", string(debug.Stack())) + if err := recover(); err != nil { + e.logger.Printf("Could not defer function %v: %+v", getFunctionName(f), err) + e.logger.Printf("Called from %s", string(debug.Stack())) } }() diff --git a/deferred_executor_test.go b/deferred_executor_test.go index ed71c0c..183b2ae 100644 --- a/deferred_executor_test.go +++ b/deferred_executor_test.go @@ -29,7 +29,8 @@ import ( ) func TestDeferredExecutor_MultiClose(t *testing.T) { - e := NewDeferredExecutor(0) + logger := NewLoggerForTest(t) + e := NewDeferredExecutor(logger, 0) defer e.waitForStop() e.Close() @@ -38,7 +39,8 @@ func TestDeferredExecutor_MultiClose(t *testing.T) { func TestDeferredExecutor_QueueSize(t *testing.T) { SynctestTest(t, func(t *testing.T) { - e := NewDeferredExecutor(0) + logger := NewLoggerForTest(t) + e := NewDeferredExecutor(logger, 0) defer e.waitForStop() defer e.Close() @@ -59,7 +61,8 @@ func TestDeferredExecutor_QueueSize(t *testing.T) { } func TestDeferredExecutor_Order(t *testing.T) { - e := NewDeferredExecutor(64) + logger := NewLoggerForTest(t) + e := NewDeferredExecutor(logger, 64) defer e.waitForStop() defer e.Close() @@ -86,7 +89,8 @@ func TestDeferredExecutor_Order(t *testing.T) { } func TestDeferredExecutor_CloseFromFunc(t *testing.T) { - e := NewDeferredExecutor(64) + logger := NewLoggerForTest(t) + e := NewDeferredExecutor(logger, 64) defer e.waitForStop() done := make(chan struct{}) @@ -99,8 +103,8 @@ func TestDeferredExecutor_CloseFromFunc(t *testing.T) { } func TestDeferredExecutor_DeferAfterClose(t *testing.T) { - CatchLogForTest(t) - e := NewDeferredExecutor(64) + logger := NewLoggerForTest(t) + e := NewDeferredExecutor(logger, 64) defer e.waitForStop() e.Close() @@ -111,7 +115,8 @@ func TestDeferredExecutor_DeferAfterClose(t *testing.T) { } func TestDeferredExecutor_WaitForStopTwice(t *testing.T) { - e := NewDeferredExecutor(64) + logger := NewLoggerForTest(t) + e := NewDeferredExecutor(logger, 64) defer e.waitForStop() e.Close() diff --git a/dns_monitor.go b/dns_monitor.go index dcfba64..6906057 100644 --- a/dns_monitor.go +++ b/dns_monitor.go @@ -23,7 +23,6 @@ package signaling import ( "context" - "log" "net" "net/url" "slices" @@ -159,6 +158,7 @@ func (e *dnsMonitorEntry) runCallbacks(all []net.IP, add []net.IP, keep []net.IP } type DnsMonitor struct { + logger Logger interval time.Duration stopCtx context.Context @@ -176,13 +176,14 @@ type DnsMonitor struct { checkHostnames func() } -func NewDnsMonitor(interval time.Duration) (*DnsMonitor, error) { +func NewDnsMonitor(logger Logger, interval time.Duration) (*DnsMonitor, error) { if interval < 0 { interval = defaultDnsMonitorInterval } stopCtx, stopFunc := context.WithCancel(context.Background()) monitor := &DnsMonitor{ + logger: logger, interval: interval, stopCtx: stopCtx, @@ -348,7 +349,7 @@ func (m *DnsMonitor) checkHostname(entry *dnsMonitorEntry) { ips, err := lookupDnsMonitorIP(entry.hostname) if err != nil { - log.Printf("Could not lookup %s: %s", entry.hostname, err) + m.logger.Printf("Could not lookup %s: %s", entry.hostname, err) return } diff --git a/dns_monitor_test.go b/dns_monitor_test.go index daa92ef..7d43dd7 100644 --- a/dns_monitor_test.go +++ b/dns_monitor_test.go @@ -90,7 +90,8 @@ func newDnsMonitorForTest(t *testing.T, interval time.Duration) *DnsMonitor { t.Helper() require := require.New(t) - monitor, err := NewDnsMonitor(interval) + logger := NewLoggerForTest(t) + monitor, err := NewDnsMonitor(logger, interval) require.NoError(err) t.Cleanup(func() { diff --git a/etcd_client.go b/etcd_client.go index 1780679..15ce75e 100644 --- a/etcd_client.go +++ b/etcd_client.go @@ -25,7 +25,6 @@ import ( "context" "errors" "fmt" - "log" "slices" "sync" "sync/atomic" @@ -53,6 +52,7 @@ type EtcdClientWatcher interface { } type EtcdClient struct { + logger Logger compatSection string mu sync.Mutex @@ -61,8 +61,9 @@ type EtcdClient struct { listeners map[EtcdClientListener]bool } -func NewEtcdClient(config *goconf.ConfigFile, compatSection string) (*EtcdClient, error) { +func NewEtcdClient(logger Logger, config *goconf.ConfigFile, compatSection string) (*EtcdClient, error) { result := &EtcdClient{ + logger: logger, compatSection: compatSection, } if err := result.load(config, false); err != nil { @@ -96,7 +97,7 @@ func (c *EtcdClient) getConfigStringWithFallback(config *goconf.ConfigFile, opti if value == "" && c.compatSection != "" { value, _ = config.GetString(c.compatSection, option) if value != "" { - log.Printf("WARNING: Configuring etcd option \"%s\" in section \"%s\" is deprecated, use section \"etcd\" instead", option, c.compatSection) + c.logger.Printf("WARNING: Configuring etcd option \"%s\" in section \"%s\" is deprecated, use section \"etcd\" instead", option, c.compatSection) } } @@ -124,7 +125,7 @@ func (c *EtcdClient) load(config *goconf.ConfigFile, ignoreErrors bool) error { return nil } - log.Printf("No etcd endpoints configured, not changing client") + c.logger.Printf("No etcd endpoints configured, not changing client") } else { cfg := clientv3.Config{ Endpoints: endpoints, @@ -159,7 +160,7 @@ func (c *EtcdClient) load(config *goconf.ConfigFile, ignoreErrors bool) error { return fmt.Errorf("could not setup etcd TLS configuration: %w", err) } - log.Printf("Could not setup TLS configuration, will be disabled (%s)", err) + c.logger.Printf("Could not setup TLS configuration, will be disabled (%s)", err) } else { cfg.TLS = tlsConfig } @@ -171,14 +172,14 @@ func (c *EtcdClient) load(config *goconf.ConfigFile, ignoreErrors bool) error { return err } - log.Printf("Could not create new client from etd endpoints %+v: %s", endpoints, err) + c.logger.Printf("Could not create new client from etd endpoints %+v: %s", endpoints, err) } else { prev := c.getEtcdClient() if prev != nil { prev.Close() } c.client.Store(client) - log.Printf("Using etcd endpoints %+v", endpoints) + c.logger.Printf("Using etcd endpoints %+v", endpoints) c.notifyListeners() } } @@ -259,16 +260,16 @@ func (c *EtcdClient) WaitForConnection(ctx context.Context) error { if errors.Is(err, context.Canceled) { return err } else if errors.Is(err, context.DeadlineExceeded) { - log.Printf("Timeout waiting for etcd client to connect to the cluster, retry in %s", backoff.NextWait()) + c.logger.Printf("Timeout waiting for etcd client to connect to the cluster, retry in %s", backoff.NextWait()) } else { - log.Printf("Could not sync etcd client with the cluster, retry in %s: %s", backoff.NextWait(), err) + c.logger.Printf("Could not sync etcd client with the cluster, retry in %s: %s", backoff.NextWait(), err) } backoff.Wait(ctx) continue } - log.Printf("Client synced, using endpoints %+v", c.getEtcdClient().Endpoints()) + c.logger.Printf("Client synced, using endpoints %+v", c.getEtcdClient().Endpoints()) return nil } } @@ -278,10 +279,10 @@ func (c *EtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOpt } func (c *EtcdClient) Watch(ctx context.Context, key string, nextRevision int64, watcher EtcdClientWatcher, opts ...clientv3.OpOption) (int64, error) { - log.Printf("Wait for leader and start watching on %s (rev=%d)", key, nextRevision) + c.logger.Printf("Wait for leader and start watching on %s (rev=%d)", key, nextRevision) opts = append(opts, clientv3.WithRev(nextRevision), clientv3.WithPrevKV()) ch := c.getEtcdClient().Watch(clientv3.WithRequireLeader(ctx), key, opts...) - log.Printf("Watch created for %s", key) + c.logger.Printf("Watch created for %s", key) watcher.EtcdWatchCreated(c, key) for response := range ch { if err := response.Err(); err != nil { @@ -304,7 +305,7 @@ func (c *EtcdClient) Watch(ctx context.Context, key string, nextRevision int64, } watcher.EtcdKeyDeleted(c, string(ev.Kv.Key), prevValue) default: - log.Printf("Unsupported watch event %s %q -> %q", ev.Type, ev.Kv.Key, ev.Kv.Value) + c.logger.Printf("Unsupported watch event %s %q -> %q", ev.Type, ev.Kv.Key, ev.Kv.Value) } } } diff --git a/etcd_client_test.go b/etcd_client_test.go index c48e580..8c2e07d 100644 --- a/etcd_client_test.go +++ b/etcd_client_test.go @@ -154,7 +154,8 @@ func NewEtcdClientForTest(t *testing.T) (*embed.Etcd, *EtcdClient) { config.AddOption("etcd", "endpoints", etcd.Config().ListenClientUrls[0].String()) config.AddOption("etcd", "loglevel", "error") - client, err := NewEtcdClient(config, "") + logger := NewLoggerForTest(t) + client, err := NewEtcdClient(logger, config, "") require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, client.Close()) @@ -172,7 +173,8 @@ func NewEtcdClientWithTLSForTest(t *testing.T) (*embed.Etcd, *EtcdClient) { config.AddOption("etcd", "clientcert", certfile) config.AddOption("etcd", "cacert", certfile) - client, err := NewEtcdClient(config, "") + logger := NewLoggerForTest(t) + client, err := NewEtcdClient(logger, config, "") require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, client.Close()) @@ -196,12 +198,13 @@ func DeleteEtcdValue(etcd *embed.Etcd, key string) { func Test_EtcdClient_Get(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) require := require.New(t) etcd, client := NewEtcdClientForTest(t) - ctx, cancel := context.WithTimeout(t.Context(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() if info := client.GetServerInfoEtcd(); assert.NotNil(info) { @@ -226,13 +229,13 @@ func Test_EtcdClient_Get(t *testing.T) { } } - if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) { + if response, err := client.Get(ctx, "foo"); assert.NoError(err) { assert.EqualValues(0, response.Count) } SetEtcdValue(etcd, "foo", []byte("bar")) - if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) { + if response, err := client.Get(ctx, "foo"); assert.NoError(err) { if assert.EqualValues(1, response.Count) { assert.Equal("foo", string(response.Kvs[0].Key)) assert.Equal("bar", string(response.Kvs[0].Value)) @@ -242,12 +245,13 @@ func Test_EtcdClient_Get(t *testing.T) { func Test_EtcdClientTLS_Get(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) require := require.New(t) etcd, client := NewEtcdClientWithTLSForTest(t) - ctx, cancel := context.WithTimeout(t.Context(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() if info := client.GetServerInfoEtcd(); assert.NotNil(info) { @@ -288,11 +292,12 @@ func Test_EtcdClientTLS_Get(t *testing.T) { func Test_EtcdClient_GetPrefix(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) etcd, client := NewEtcdClientForTest(t) - if response, err := client.Get(context.Background(), "foo"); assert.NoError(err) { + if response, err := client.Get(ctx, "foo"); assert.NoError(err) { assert.EqualValues(0, response.Count) } @@ -300,7 +305,7 @@ func Test_EtcdClient_GetPrefix(t *testing.T) { SetEtcdValue(etcd, "foo/lala", []byte("2")) SetEtcdValue(etcd, "lala/foo", []byte("3")) - if response, err := client.Get(context.Background(), "foo", clientv3.WithPrefix()); assert.NoError(err) { + if response, err := client.Get(ctx, "foo", clientv3.WithPrefix()); assert.NoError(err) { if assert.EqualValues(2, response.Count) { assert.Equal("foo", string(response.Kvs[0].Key)) assert.Equal("1", string(response.Kvs[0].Value)) @@ -399,13 +404,14 @@ func (l *EtcdClientTestListener) EtcdKeyDeleted(client *EtcdClient, key string, func Test_EtcdClient_Watch(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) etcd, client := NewEtcdClientForTest(t) SetEtcdValue(etcd, "foo/a", []byte("1")) - listener := NewEtcdClientTestListener(context.Background(), t) + listener := NewEtcdClientTestListener(ctx, t) defer listener.Close() client.AddListener(listener) diff --git a/federation.go b/federation.go index 218b09a..5fa9ec1 100644 --- a/federation.go +++ b/federation.go @@ -27,7 +27,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "net" "strconv" "strings" @@ -78,6 +77,7 @@ func getCloudUrl(s string) string { } type FederationClient struct { + logger Logger hub *Hub session *ClientSession message atomic.Pointer[ClientMessage] @@ -144,6 +144,7 @@ func NewFederationClient(ctx context.Context, hub *Hub, session *ClientSession, } result := &FederationClient{ + logger: hub.logger, hub: hub, session: session, @@ -203,7 +204,7 @@ func (c *FederationClient) CanReuse(federation *RoomFederationMessage) bool { } func (c *FederationClient) connect(ctx context.Context) error { - log.Printf("Creating federation connection to %s for %s", c.URL(), c.session.PublicId()) + c.logger.Printf("Creating federation connection to %s for %s", c.URL(), c.session.PublicId()) conn, response, err := c.dialer.DialContext(ctx, c.url, nil) if err != nil { return err @@ -220,13 +221,13 @@ func (c *FederationClient) connect(ctx context.Context) error { } if !supportsFederation { if err := conn.Close(); err != nil { - log.Printf("Error closing federation connection to %s: %s", c.URL(), err) + c.logger.Printf("Error closing federation connection to %s: %s", c.URL(), err) } return ErrFederationNotSupported } - log.Printf("Federation connection established to %s for %s", c.URL(), c.session.PublicId()) + c.logger.Printf("Federation connection established to %s for %s", c.URL(), c.session.PublicId()) c.mu.Lock() defer c.mu.Unlock() @@ -303,18 +304,18 @@ func (c *FederationClient) closeConnection(withBye bool) { if err := c.sendMessageLocked(&ClientMessage{ Type: "bye", }); err != nil && !isClosedError(err) { - log.Printf("Error sending bye on federation connection to %s: %s", c.URL(), err) + c.logger.Printf("Error sending bye on federation connection to %s: %s", c.URL(), err) } } closeMessage := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") deadline := time.Now().Add(writeWait) if err := c.conn.WriteControl(websocket.CloseMessage, closeMessage, deadline); err != nil && !isClosedError(err) { - log.Printf("Error sending close message on federation connection to %s: %s", c.URL(), err) + c.logger.Printf("Error sending close message on federation connection to %s: %s", c.URL(), err) } if err := c.conn.Close(); err != nil && !isClosedError(err) { - log.Printf("Error closing federation connection to %s: %s", c.URL(), err) + c.logger.Printf("Error closing federation connection to %s: %s", c.URL(), err) } c.conn = nil @@ -362,11 +363,12 @@ func (c *FederationClient) reconnect() { return } - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.hub.federationTimeout)) + ctx := NewLoggerContext(context.Background(), c.logger) + ctx, cancel := context.WithTimeout(ctx, time.Duration(c.hub.federationTimeout)) defer cancel() if err := c.connect(ctx); err != nil { - log.Printf("Error connecting to federation server %s for %s: %s", c.URL(), c.session.PublicId(), err) + c.logger.Printf("Error connecting to federation server %s for %s: %s", c.URL(), c.session.PublicId(), err) c.scheduleReconnect() return } @@ -390,7 +392,7 @@ func (c *FederationClient) readPump(conn *websocket.Conn) { } if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { - log.Printf("Error reading from %s for %s: %s", c.URL(), c.session.PublicId(), err) + c.logger.Printf("Error reading from %s for %s: %s", c.URL(), c.session.PublicId(), err) } c.scheduleReconnect() @@ -403,7 +405,7 @@ func (c *FederationClient) readPump(conn *websocket.Conn) { var msg ServerMessage if err := json.Unmarshal(data, &msg); err != nil { - log.Printf("Error unmarshalling %s from %s: %s", string(data), c.URL(), err) + c.logger.Printf("Error unmarshalling %s from %s: %s", string(data), c.URL(), err) continue } @@ -432,7 +434,7 @@ func (c *FederationClient) sendPing() { msg := strconv.FormatInt(now, 10) c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil { - log.Printf("Could not send ping to federated client %s for %s: %v", c.URL(), c.session.PublicId(), err) + c.logger.Printf("Could not send ping to federated client %s for %s: %v", c.URL(), c.session.PublicId(), err) c.scheduleReconnectLocked() } } @@ -517,7 +519,7 @@ func (c *FederationClient) processWelcome(msg *ServerMessage) { Token: c.federation.Load().Token, } if err := c.sendHello(federationParams); err != nil { - log.Printf("Error sending hello message to %s for %s: %s", c.URL(), c.session.PublicId(), err) + c.logger.Printf("Error sending hello message to %s for %s: %s", c.URL(), c.session.PublicId(), err) c.closeWithError(err) } } @@ -529,7 +531,7 @@ func (c *FederationClient) processHello(msg *ServerMessage) { defer c.helloMu.Unlock() if msg.Id != c.helloMsgId { - log.Printf("Received hello response %+v for unknown request, expected %s", msg, c.helloMsgId) + c.logger.Printf("Received hello response %+v for unknown request, expected %s", msg, c.helloMsgId) if err := c.sendHelloLocked(c.helloAuth); err != nil { c.closeWithError(err) } @@ -548,12 +550,12 @@ func (c *FederationClient) processHello(msg *ServerMessage) { c.closeWithError(err) } default: - log.Printf("Received hello error from federated client for %s to %s: %+v", c.session.PublicId(), c.URL(), msg) + c.logger.Printf("Received hello error from federated client for %s to %s: %+v", c.session.PublicId(), c.URL(), msg) c.closeWithError(msg.Error) } return } else if msg.Type != "hello" { - log.Printf("Received unknown hello response from federated client for %s to %s: %+v", c.session.PublicId(), c.URL(), msg) + c.logger.Printf("Received unknown hello response from federated client for %s to %s: %+v", c.session.PublicId(), c.URL(), msg) if err := c.sendHelloLocked(c.helloAuth); err != nil { c.closeWithError(err) } @@ -594,7 +596,7 @@ func (c *FederationClient) processHello(msg *ServerMessage) { messages := c.pendingMessages c.pendingMessages = nil - log.Printf("Sending %d pending messages to %s for %s", count, c.URL(), c.session.PublicId()) + c.logger.Printf("Sending %d pending messages to %s for %s", count, c.URL(), c.session.PublicId()) c.helloMu.Unlock() defer c.helloMu.Lock() @@ -603,7 +605,7 @@ func (c *FederationClient) processHello(msg *ServerMessage) { defer c.mu.Unlock() for _, msg := range messages { if err := c.sendMessageLocked(msg); err != nil { - log.Printf("Error sending pending message %+v on federation connection to %s: %s", msg, c.URL(), err) + c.logger.Printf("Error sending pending message %+v on federation connection to %s: %s", msg, c.URL(), err) break } } @@ -966,7 +968,7 @@ func (c *FederationClient) deferMessage(message *ClientMessage) { c.pendingMessages = append(c.pendingMessages, message) if len(c.pendingMessages) >= warnPendingMessagesCount { - log.Printf("Session %s has %d pending federated messages", c.session.PublicId(), len(c.pendingMessages)) + c.logger.Printf("Session %s has %d pending federated messages", c.session.PublicId(), len(c.pendingMessages)) } } @@ -999,7 +1001,7 @@ func (c *FederationClient) sendMessageLocked(message *ClientMessage) error { return err } - log.Printf("Could not send message %+v for %s to federated client %s: %v", message, c.session.PublicId(), c.URL(), err) + c.logger.Printf("Could not send message %+v for %s to federated client %s: %v", message, c.session.PublicId(), c.URL(), err) c.deferMessage(message) c.scheduleReconnectLocked() } diff --git a/federation_test.go b/federation_test.go index cdfc720..e97996b 100644 --- a/federation_test.go +++ b/federation_test.go @@ -36,8 +36,6 @@ import ( ) func Test_FederationInvalidToken(t *testing.T) { - CatchLogForTest(t) - assert := assert.New(t) require := require.New(t) @@ -75,8 +73,6 @@ func Test_FederationInvalidToken(t *testing.T) { } func Test_Federation(t *testing.T) { - CatchLogForTest(t) - assert := assert.New(t) require := require.New(t) @@ -492,8 +488,6 @@ func Test_Federation(t *testing.T) { } func Test_FederationJoinRoomTwice(t *testing.T) { - CatchLogForTest(t) - assert := assert.New(t) require := require.New(t) @@ -599,8 +593,6 @@ func Test_FederationJoinRoomTwice(t *testing.T) { } func Test_FederationChangeRoom(t *testing.T) { - CatchLogForTest(t) - assert := assert.New(t) require := require.New(t) @@ -708,8 +700,6 @@ func Test_FederationChangeRoom(t *testing.T) { } func Test_FederationMedia(t *testing.T) { - CatchLogForTest(t) - assert := assert.New(t) require := require.New(t) @@ -718,15 +708,13 @@ func Test_FederationMedia(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - mcu1, err := NewTestMCU() - require.NoError(err) + mcu1 := NewTestMCU(t) require.NoError(mcu1.Start(ctx)) defer mcu1.Stop() hub1.SetMcu(mcu1) - mcu2, err := NewTestMCU() - require.NoError(err) + mcu2 := NewTestMCU(t) require.NoError(mcu2.Start(ctx)) defer mcu2.Stop() @@ -815,8 +803,6 @@ func Test_FederationMedia(t *testing.T) { } func Test_FederationResume(t *testing.T) { - CatchLogForTest(t) - assert := assert.New(t) require := require.New(t) @@ -936,8 +922,6 @@ func Test_FederationResume(t *testing.T) { } func Test_FederationResumeNewSession(t *testing.T) { - CatchLogForTest(t) - assert := assert.New(t) require := require.New(t) diff --git a/file_watcher.go b/file_watcher.go index a26d0a7..489ed10 100644 --- a/file_watcher.go +++ b/file_watcher.go @@ -24,7 +24,6 @@ package signaling import ( "context" "errors" - "log" "os" "path" "path/filepath" @@ -51,6 +50,7 @@ func init() { type FileWatcherCallback func(filename string) type FileWatcher struct { + logger Logger filename string target string callback FileWatcherCallback @@ -60,7 +60,7 @@ type FileWatcher struct { closeFunc context.CancelFunc } -func NewFileWatcher(filename string, callback FileWatcherCallback) (*FileWatcher, error) { +func NewFileWatcher(logger Logger, filename string, callback FileWatcherCallback) (*FileWatcher, error) { watcher, err := fsnotify.NewWatcher() if err != nil { return nil, err @@ -74,6 +74,7 @@ func NewFileWatcher(filename string, callback FileWatcherCallback) (*FileWatcher closeCtx, closeFunc := context.WithCancel(context.Background()) w := &FileWatcher{ + logger: logger, filename: filename, callback: callback, watcher: watcher, @@ -157,14 +158,14 @@ func (f *FileWatcher) run() { triggerEvent(event) if err := f.updateWatcher(); err != nil { - log.Printf("Error updating watcher after %s is deleted: %s", event.Name, err) + f.logger.Printf("Error updating watcher after %s is deleted: %s", event.Name, err) } continue } if stat, err := os.Lstat(event.Name); err != nil { if !errors.Is(err, os.ErrNotExist) { - log.Printf("Could not lstat %s: %s", event.Name, err) + f.logger.Printf("Could not lstat %s: %s", event.Name, err) } } else if stat.Mode()&os.ModeSymlink != 0 { target, err := filepath.EvalSymlinks(event.Name) @@ -183,7 +184,7 @@ func (f *FileWatcher) run() { return } - log.Printf("Error watching %s: %s", f.filename, err) + f.logger.Printf("Error watching %s: %s", f.filename, err) case <-f.closeCtx.Done(): return } diff --git a/file_watcher_test.go b/file_watcher_test.go index 8dfa48c..5a29a73 100644 --- a/file_watcher_test.go +++ b/file_watcher_test.go @@ -38,7 +38,8 @@ var ( func TestFileWatcher_NotExist(t *testing.T) { assert := assert.New(t) tmpdir := t.TempDir() - if w, err := NewFileWatcher(path.Join(tmpdir, "test.txt"), func(filename string) {}); !assert.ErrorIs(err, os.ErrNotExist) { + logger := NewLoggerForTest(t) + if w, err := NewFileWatcher(logger, path.Join(tmpdir, "test.txt"), func(filename string) {}); !assert.ErrorIs(err, os.ErrNotExist) { if w != nil { assert.NoError(w.Close()) } @@ -53,8 +54,9 @@ func TestFileWatcher_File(t *testing.T) { filename := path.Join(tmpdir, "test.txt") require.NoError(os.WriteFile(filename, []byte("Hello world!"), 0644)) + logger := NewLoggerForTest(t) modified := make(chan struct{}) - w, err := NewFileWatcher(filename, func(filename string) { + w, err := NewFileWatcher(logger, filename, func(filename string) { modified <- struct{}{} }) require.NoError(err) @@ -95,8 +97,9 @@ func TestFileWatcher_CurrentDir(t *testing.T) { filename := path.Join(tmpdir, "test.txt") require.NoError(os.WriteFile(filename, []byte("Hello world!"), 0644)) + logger := NewLoggerForTest(t) modified := make(chan struct{}) - w, err := NewFileWatcher("./"+path.Base(filename), func(filename string) { + w, err := NewFileWatcher(logger, "./"+path.Base(filename), func(filename string) { modified <- struct{}{} }) require.NoError(err) @@ -135,8 +138,9 @@ func TestFileWatcher_Rename(t *testing.T) { filename := path.Join(tmpdir, "test.txt") require.NoError(os.WriteFile(filename, []byte("Hello world!"), 0644)) + logger := NewLoggerForTest(t) modified := make(chan struct{}) - w, err := NewFileWatcher(filename, func(filename string) { + w, err := NewFileWatcher(logger, filename, func(filename string) { modified <- struct{}{} }) require.NoError(err) @@ -177,8 +181,9 @@ func TestFileWatcher_Symlink(t *testing.T) { filename := path.Join(tmpdir, "symlink.txt") require.NoError(os.Symlink(sourceFilename, filename)) + logger := NewLoggerForTest(t) modified := make(chan struct{}) - w, err := NewFileWatcher(filename, func(filename string) { + w, err := NewFileWatcher(logger, filename, func(filename string) { modified <- struct{}{} }) require.NoError(err) @@ -210,8 +215,9 @@ func TestFileWatcher_ChangeSymlinkTarget(t *testing.T) { filename := path.Join(tmpdir, "symlink.txt") require.NoError(os.Symlink(sourceFilename1, filename)) + logger := NewLoggerForTest(t) modified := make(chan struct{}) - w, err := NewFileWatcher(filename, func(filename string) { + w, err := NewFileWatcher(logger, filename, func(filename string) { modified <- struct{}{} }) require.NoError(err) @@ -245,8 +251,9 @@ func TestFileWatcher_OtherSymlink(t *testing.T) { filename := path.Join(tmpdir, "symlink.txt") require.NoError(os.Symlink(sourceFilename1, filename)) + logger := NewLoggerForTest(t) modified := make(chan struct{}) - w, err := NewFileWatcher(filename, func(filename string) { + w, err := NewFileWatcher(logger, filename, func(filename string) { modified <- struct{}{} }) require.NoError(err) @@ -274,8 +281,9 @@ func TestFileWatcher_RenameSymlinkTarget(t *testing.T) { filename := path.Join(tmpdir, "test.txt") require.NoError(os.Symlink(sourceFilename1, filename)) + logger := NewLoggerForTest(t) modified := make(chan struct{}) - w, err := NewFileWatcher(filename, func(filename string) { + w, err := NewFileWatcher(logger, filename, func(filename string) { modified <- struct{}{} }) require.NoError(err) @@ -326,8 +334,9 @@ func TestFileWatcher_UpdateSymlinkFolder(t *testing.T) { filename := path.Join(tmpdir, "test.txt") require.NoError(os.Symlink("data/test.txt", filename)) + logger := NewLoggerForTest(t) modified := make(chan struct{}) - w, err := NewFileWatcher(filename, func(filename string) { + w, err := NewFileWatcher(logger, filename, func(filename string) { modified <- struct{}{} }) require.NoError(err) diff --git a/geoip.go b/geoip.go index 1051324..f48092f 100644 --- a/geoip.go +++ b/geoip.go @@ -24,9 +24,9 @@ package signaling import ( "archive/tar" "compress/gzip" + "context" "fmt" "io" - "log" "net" "net/http" "net/url" @@ -56,6 +56,7 @@ func GetGeoIpDownloadUrl(license string) string { } type GeoLookup struct { + logger Logger url string isFile bool client http.Client @@ -66,15 +67,17 @@ type GeoLookup struct { reader atomic.Pointer[maxminddb.Reader] } -func NewGeoLookupFromUrl(url string) (*GeoLookup, error) { +func NewGeoLookupFromUrl(logger Logger, url string) (*GeoLookup, error) { geoip := &GeoLookup{ - url: url, + logger: logger, + url: url, } return geoip, nil } -func NewGeoLookupFromFile(filename string) (*GeoLookup, error) { +func NewGeoLookupFromFile(logger Logger, filename string) (*GeoLookup, error) { geoip := &GeoLookup{ + logger: logger, url: filename, isFile: true, } @@ -119,7 +122,7 @@ func (g *GeoLookup) updateFile() error { } metadata := reader.Metadata - log.Printf("Using %s GeoIP database from %s (built on %s)", metadata.DatabaseType, g.url, time.Unix(int64(metadata.BuildEpoch), 0).UTC()) + g.logger.Printf("Using %s GeoIP database from %s (built on %s)", metadata.DatabaseType, g.url, time.Unix(int64(metadata.BuildEpoch), 0).UTC()) if old := g.reader.Swap(reader); old != nil { old.Close() @@ -144,7 +147,7 @@ func (g *GeoLookup) updateUrl() error { defer response.Body.Close() if response.StatusCode == http.StatusNotModified { - log.Printf("GeoIP database at %s has not changed", g.url) + g.logger.Printf("GeoIP database at %s has not changed", g.url) return nil } else if response.StatusCode/100 != 2 { return fmt.Errorf("downloading %s returned an error: %s", g.url, response.Status) @@ -202,7 +205,7 @@ func (g *GeoLookup) updateUrl() error { } metadata := reader.Metadata - log.Printf("Using %s GeoIP database from %s (built on %s)", metadata.DatabaseType, g.url, time.Unix(int64(metadata.BuildEpoch), 0).UTC()) + g.logger.Printf("Using %s GeoIP database from %s (built on %s)", metadata.DatabaseType, g.url, time.Unix(int64(metadata.BuildEpoch), 0).UTC()) if old := g.reader.Swap(reader); old != nil { old.Close() @@ -268,7 +271,8 @@ func IsValidContinent(continent string) bool { } } -func LoadGeoIPOverrides(config *goconf.ConfigFile, ignoreErrors bool) (map[*net.IPNet]string, error) { +func LoadGeoIPOverrides(ctx context.Context, config *goconf.ConfigFile, ignoreErrors bool) (map[*net.IPNet]string, error) { + logger := LoggerFromContext(ctx) options, _ := GetStringOptions(config, "geoip-overrides", true) if len(options) == 0 { return nil, nil @@ -283,7 +287,7 @@ func LoadGeoIPOverrides(config *goconf.ConfigFile, ignoreErrors bool) (map[*net. _, ipNet, err = net.ParseCIDR(option) if err != nil { if ignoreErrors { - log.Printf("could not parse CIDR %s (%s), skipping", option, err) + logger.Printf("could not parse CIDR %s (%s), skipping", option, err) continue } @@ -293,7 +297,7 @@ func LoadGeoIPOverrides(config *goconf.ConfigFile, ignoreErrors bool) (map[*net. ip = net.ParseIP(option) if ip == nil { if ignoreErrors { - log.Printf("could not parse IP %s, skipping", option) + logger.Printf("could not parse IP %s, skipping", option) continue } @@ -314,14 +318,14 @@ func LoadGeoIPOverrides(config *goconf.ConfigFile, ignoreErrors bool) (map[*net. value = strings.ToUpper(strings.TrimSpace(value)) if value == "" { - log.Printf("IP %s doesn't have a country assigned, skipping", option) + logger.Printf("IP %s doesn't have a country assigned, skipping", option) continue } else if !IsValidCountry(value) { - log.Printf("Country %s for IP %s is invalid, skipping", value, option) + logger.Printf("Country %s for IP %s is invalid, skipping", value, option) continue } - log.Printf("Using country %s for %s", value, ipNet) + logger.Printf("Using country %s for %s", value, ipNet) geoipOverrides[ipNet] = value } diff --git a/geoip_test.go b/geoip_test.go index 0aa1aa7..885749a 100644 --- a/geoip_test.go +++ b/geoip_test.go @@ -76,9 +76,9 @@ func GetGeoIpUrlForTest(t *testing.T) string { } func TestGeoLookup(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) - reader, err := NewGeoLookupFromUrl(GetGeoIpUrlForTest(t)) + reader, err := NewGeoLookupFromUrl(logger, GetGeoIpUrlForTest(t)) require.NoError(err) defer reader.Close() @@ -88,9 +88,9 @@ func TestGeoLookup(t *testing.T) { } func TestGeoLookupCaching(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) - reader, err := NewGeoLookupFromUrl(GetGeoIpUrlForTest(t)) + reader, err := NewGeoLookupFromUrl(logger, GetGeoIpUrlForTest(t)) require.NoError(err) defer reader.Close() @@ -126,14 +126,14 @@ func TestGeoLookupContinent(t *testing.T) { } func TestGeoLookupCloseEmpty(t *testing.T) { - CatchLogForTest(t) - reader, err := NewGeoLookupFromUrl("ignore-url") + logger := NewLoggerForTest(t) + reader, err := NewGeoLookupFromUrl(logger, "ignore-url") require.NoError(t, err) reader.Close() } func TestGeoLookupFromFile(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) geoIpUrl := GetGeoIpUrlForTest(t) @@ -188,7 +188,7 @@ func TestGeoLookupFromFile(t *testing.T) { require.True(foundDatabase, "Did not find GeoIP database in download from %s", geoIpUrl) - reader, err := NewGeoLookupFromFile(tmpfile.Name()) + reader, err := NewGeoLookupFromFile(logger, tmpfile.Name()) require.NoError(err) defer reader.Close() diff --git a/grpc_client.go b/grpc_client.go index 38a4270..81b449e 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -27,7 +27,6 @@ import ( "errors" "fmt" "io" - "log" "net" "slices" "sync" @@ -79,6 +78,7 @@ func newGrpcClientImpl(conn grpc.ClientConnInterface) *grpcClientImpl { } type GrpcClient struct { + logger Logger ip net.IP rawTarget string target string @@ -127,7 +127,7 @@ func (r *customIpResolver) Close() { // Noop } -func NewGrpcClient(target string, ip net.IP, opts ...grpc.DialOption) (*GrpcClient, error) { +func NewGrpcClient(logger Logger, target string, ip net.IP, opts ...grpc.DialOption) (*GrpcClient, error) { var conn *grpc.ClientConn var err error if ip != nil { @@ -153,6 +153,7 @@ func NewGrpcClient(target string, ip net.IP, opts ...grpc.DialOption) (*GrpcClie } result := &GrpcClient{ + logger: logger, ip: ip, rawTarget: target, target: target, @@ -200,7 +201,7 @@ func (c *GrpcClient) GetServerId(ctx context.Context) (string, string, error) { func (c *GrpcClient) LookupResumeId(ctx context.Context, resumeId PrivateSessionId) (*LookupResumeIdReply, error) { statsGrpcClientCalls.WithLabelValues("LookupResumeId").Inc() // TODO: Remove debug logging - log.Printf("Lookup resume id %s on %s", resumeId, c.Target()) + c.logger.Printf("Lookup resume id %s on %s", resumeId, c.Target()) response, err := c.impl.LookupResumeId(ctx, &LookupResumeIdRequest{ ResumeId: string(resumeId), }, grpc.WaitForReady(true)) @@ -220,7 +221,7 @@ func (c *GrpcClient) LookupResumeId(ctx context.Context, resumeId PrivateSession func (c *GrpcClient) LookupSessionId(ctx context.Context, roomSessionId RoomSessionId, disconnectReason string) (PublicSessionId, error) { statsGrpcClientCalls.WithLabelValues("LookupSessionId").Inc() // TODO: Remove debug logging - log.Printf("Lookup room session %s on %s", roomSessionId, c.Target()) + c.logger.Printf("Lookup room session %s on %s", roomSessionId, c.Target()) response, err := c.impl.LookupSessionId(ctx, &LookupSessionIdRequest{ RoomSessionId: string(roomSessionId), DisconnectReason: disconnectReason, @@ -242,7 +243,7 @@ func (c *GrpcClient) LookupSessionId(ctx context.Context, roomSessionId RoomSess func (c *GrpcClient) IsSessionInCall(ctx context.Context, sessionId PublicSessionId, room *Room, backendUrl string) (bool, error) { statsGrpcClientCalls.WithLabelValues("IsSessionInCall").Inc() // TODO: Remove debug logging - log.Printf("Check if session %s is in call %s on %s", sessionId, room.Id(), c.Target()) + c.logger.Printf("Check if session %s is in call %s on %s", sessionId, room.Id(), c.Target()) response, err := c.impl.IsSessionInCall(ctx, &IsSessionInCallRequest{ SessionId: string(sessionId), RoomId: room.Id(), @@ -260,7 +261,7 @@ func (c *GrpcClient) IsSessionInCall(ctx context.Context, sessionId PublicSessio func (c *GrpcClient) GetInternalSessions(ctx context.Context, roomId string, backendUrls []string) (internal map[PublicSessionId]*InternalSessionData, virtual map[PublicSessionId]*VirtualSessionData, err error) { statsGrpcClientCalls.WithLabelValues("GetInternalSessions").Inc() // TODO: Remove debug logging - log.Printf("Get internal sessions for %s on %s", roomId, c.Target()) + c.logger.Printf("Get internal sessions for %s on %s", roomId, c.Target()) var backendUrl string if len(backendUrls) > 0 { backendUrl = backendUrls[0] @@ -295,7 +296,7 @@ func (c *GrpcClient) GetInternalSessions(ctx context.Context, roomId string, bac func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId PublicSessionId, streamType StreamType) (PublicSessionId, string, net.IP, string, string, error) { statsGrpcClientCalls.WithLabelValues("GetPublisherId").Inc() // TODO: Remove debug logging - log.Printf("Get %s publisher id %s on %s", streamType, sessionId, c.Target()) + c.logger.Printf("Get %s publisher id %s on %s", streamType, sessionId, c.Target()) response, err := c.impl.GetPublisherId(ctx, &GetPublisherIdRequest{ SessionId: string(sessionId), StreamType: string(streamType), @@ -312,7 +313,7 @@ func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId PublicSession func (c *GrpcClient) GetSessionCount(ctx context.Context, url string) (uint32, error) { statsGrpcClientCalls.WithLabelValues("GetSessionCount").Inc() // TODO: Remove debug logging - log.Printf("Get session count for %s on %s", url, c.Target()) + c.logger.Printf("Get session count for %s on %s", url, c.Target()) response, err := c.impl.GetSessionCount(ctx, &GetSessionCountRequest{ Url: url, }, grpc.WaitForReady(true)) @@ -335,6 +336,7 @@ type ProxySessionReceiver interface { } type SessionProxy struct { + logger Logger sessionId PublicSessionId receiver ProxySessionReceiver @@ -347,7 +349,7 @@ func (p *SessionProxy) recvPump() { defer func() { p.receiver.OnProxyClose(closeError) if err := p.Close(); err != nil { - log.Printf("Error closing proxy for session %s: %s", p.sessionId, err) + p.logger.Printf("Error closing proxy for session %s: %s", p.sessionId, err) } }() @@ -358,13 +360,13 @@ func (p *SessionProxy) recvPump() { break } - log.Printf("Error receiving message from proxy for session %s: %s", p.sessionId, err) + p.logger.Printf("Error receiving message from proxy for session %s: %s", p.sessionId, err) closeError = err break } if err := p.receiver.OnProxyMessage(msg); err != nil { - log.Printf("Error processing message %+v from proxy for session %s: %s", msg, p.sessionId, err) + p.logger.Printf("Error processing message %+v from proxy for session %s: %s", msg, p.sessionId, err) } } } @@ -395,6 +397,7 @@ func (c *GrpcClient) ProxySession(ctx context.Context, sessionId PublicSessionId } proxy := &SessionProxy{ + logger: c.logger, sessionId: sessionId, receiver: receiver, @@ -413,6 +416,7 @@ type grpcClientsList struct { type GrpcClients struct { mu sync.RWMutex version string + logger Logger // +checklocks:mu clientsMap map[string]*grpcClientsList @@ -439,11 +443,12 @@ type GrpcClients struct { closeFunc context.CancelFunc // +checklocksignore: No locking necessary. } -func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient, dnsMonitor *DnsMonitor, version string) (*GrpcClients, error) { +func NewGrpcClients(ctx context.Context, config *goconf.ConfigFile, etcdClient *EtcdClient, dnsMonitor *DnsMonitor, version string) (*GrpcClients, error) { initializedCtx, initializedFunc := context.WithCancel(context.Background()) closeCtx, closeFunc := context.WithCancel(context.Background()) result := &GrpcClients{ version: version, + logger: LoggerFromContext(ctx), dnsMonitor: dnsMonitor, etcdClient: etcdClient, initializedCtx: initializedCtx, @@ -484,7 +489,7 @@ func (c *GrpcClients) GetServerInfoGrpc() (result []BackendServerInfoGrpc) { } func (c *GrpcClients) load(config *goconf.ConfigFile, fromReload bool) error { - creds, err := NewReloadableCredentials(config, false) + creds, err := NewReloadableCredentials(c.logger, config, false) if err != nil { return err } @@ -522,7 +527,7 @@ func (c *GrpcClients) closeClient(client *GrpcClient) { } if err := client.Close(); err != nil { - log.Printf("Error closing client to %s: %s", client.Target(), err) + c.logger.Printf("Error closing client to %s: %s", client.Target(), err) } } @@ -568,7 +573,7 @@ loop: } if status.Code(err) != codes.Canceled { - log.Printf("Error checking GRPC server id of %s, retrying in %s: %s", client.Target(), backoff.NextWait(), err) + c.logger.Printf("Error checking GRPC server id of %s, retrying in %s: %s", client.Target(), backoff.NextWait(), err) } backoff.Wait(ctx) continue @@ -576,13 +581,13 @@ loop: client.version.Store(version) if id == GrpcServerId { - log.Printf("GRPC target %s is this server, removing", client.Target()) + c.logger.Printf("GRPC target %s is this server, removing", client.Target()) c.closeClient(client) client.SetSelf(true) } else if version != c.version { - log.Printf("WARNING: Node %s is runing different version %s than local node (%s)", client.Target(), version, c.version) + c.logger.Printf("WARNING: Node %s is runing different version %s than local node (%s)", client.Target(), version, c.version) } else { - log.Printf("Checked GRPC server id of %s running version %s", client.Target(), version) + c.logger.Printf("Checked GRPC server id of %s running version %s", client.Target(), version) } break loop } @@ -648,7 +653,7 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo continue } - client, err := NewGrpcClient(target, nil, opts...) + client, err := NewGrpcClient(c.logger, target, nil, opts...) if err != nil { for _, entry := range clientsMap { for _, client := range entry.clients { @@ -666,7 +671,7 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo c.selfCheckWaitGroup.Add(1) go c.checkIsSelf(c.closeCtx, target, client) - log.Printf("Adding %s as GRPC target", client.Target()) + c.logger.Printf("Adding %s as GRPC target", client.Target()) entry, found := clientsMap[target] if !found { entry = &grpcClientsList{} @@ -679,7 +684,7 @@ func (c *GrpcClients) loadTargetsStatic(config *goconf.ConfigFile, fromReload bo for target := range removeTargets { if entry, found := clientsMap[target]; found { for _, client := range entry.clients { - log.Printf("Deleting GRPC target %s", client.Target()) + c.logger.Printf("Deleting GRPC target %s", client.Target()) c.closeClient(client) } @@ -716,7 +721,7 @@ func (c *GrpcClients) onLookup(entry *DnsMonitorEntry, all []net.IP, added []net for _, client := range e.clients { if ip.Equal(client.ip) { mapModified = true - log.Printf("Removing connection to %s", client.Target()) + c.logger.Printf("Removing connection to %s", client.Target()) c.closeClient(client) c.wakeupForTesting() } @@ -732,16 +737,16 @@ func (c *GrpcClients) onLookup(entry *DnsMonitorEntry, all []net.IP, added []net } for _, ip := range added { - client, err := NewGrpcClient(target, ip, opts...) + client, err := NewGrpcClient(c.logger, target, ip, opts...) if err != nil { - log.Printf("Error creating client to %s with IP %s: %s", target, ip.String(), err) + c.logger.Printf("Error creating client to %s with IP %s: %s", target, ip.String(), err) continue } c.selfCheckWaitGroup.Add(1) go c.checkIsSelf(c.closeCtx, target, client) - log.Printf("Adding %s as GRPC target", client.Target()) + c.logger.Printf("Adding %s as GRPC target", client.Target()) newClients = append(newClients, client) mapModified = true c.wakeupForTesting() @@ -797,9 +802,9 @@ func (c *GrpcClients) EtcdClientCreated(client *EtcdClient) { if errors.Is(err, context.Canceled) { return } else if errors.Is(err, context.DeadlineExceeded) { - log.Printf("Timeout getting initial list of GRPC targets, retry in %s", backoff.NextWait()) + c.logger.Printf("Timeout getting initial list of GRPC targets, retry in %s", backoff.NextWait()) } else { - log.Printf("Could not get initial list of GRPC targets, retry in %s: %s", backoff.NextWait(), err) + c.logger.Printf("Could not get initial list of GRPC targets, retry in %s: %s", backoff.NextWait(), err) } backoff.Wait(c.closeCtx) @@ -819,7 +824,7 @@ func (c *GrpcClients) EtcdClientCreated(client *EtcdClient) { for c.closeCtx.Err() == nil { var err error if nextRevision, err = client.Watch(c.closeCtx, c.targetPrefix, nextRevision, c, clientv3.WithPrefix()); err != nil { - log.Printf("Error processing watch for %s (%s), retry in %s", c.targetPrefix, err, backoff.NextWait()) + c.logger.Printf("Error processing watch for %s (%s), retry in %s", c.targetPrefix, err, backoff.NextWait()) backoff.Wait(c.closeCtx) continue } @@ -828,7 +833,7 @@ func (c *GrpcClients) EtcdClientCreated(client *EtcdClient) { backoff.Reset() prevRevision = nextRevision } else { - log.Printf("Processing watch for %s interrupted, retry in %s", c.targetPrefix, backoff.NextWait()) + c.logger.Printf("Processing watch for %s interrupted, retry in %s", c.targetPrefix, backoff.NextWait()) backoff.Wait(c.closeCtx) } } @@ -848,11 +853,11 @@ func (c *GrpcClients) getGrpcTargets(ctx context.Context, client *EtcdClient, ta func (c *GrpcClients) EtcdKeyUpdated(client *EtcdClient, key string, data []byte, prevValue []byte) { var info GrpcTargetInformationEtcd if err := json.Unmarshal(data, &info); err != nil { - log.Printf("Could not decode GRPC target %s=%s: %s", key, string(data), err) + c.logger.Printf("Could not decode GRPC target %s=%s: %s", key, string(data), err) return } if err := info.CheckValid(); err != nil { - log.Printf("Received invalid GRPC target %s=%s: %s", key, string(data), err) + c.logger.Printf("Received invalid GRPC target %s=%s: %s", key, string(data), err) return } @@ -866,21 +871,21 @@ func (c *GrpcClients) EtcdKeyUpdated(client *EtcdClient, key string, data []byte } if _, found := c.clientsMap[info.Address]; found { - log.Printf("GRPC target %s already exists, ignoring %s", info.Address, key) + c.logger.Printf("GRPC target %s already exists, ignoring %s", info.Address, key) return } opts := c.dialOptions.Load().([]grpc.DialOption) - cl, err := NewGrpcClient(info.Address, nil, opts...) + cl, err := NewGrpcClient(c.logger, info.Address, nil, opts...) if err != nil { - log.Printf("Could not create GRPC client for target %s: %s", info.Address, err) + c.logger.Printf("Could not create GRPC client for target %s: %s", info.Address, err) return } c.selfCheckWaitGroup.Add(1) go c.checkIsSelf(c.closeCtx, info.Address, cl) - log.Printf("Adding %s as GRPC target", cl.Target()) + c.logger.Printf("Adding %s as GRPC target", cl.Target()) if c.clientsMap == nil { c.clientsMap = make(map[string]*grpcClientsList) @@ -905,7 +910,7 @@ func (c *GrpcClients) EtcdKeyDeleted(client *EtcdClient, key string, prevValue [ func (c *GrpcClients) removeEtcdClientLocked(key string) { info, found := c.targetInformation[key] if !found { - log.Printf("No connection found for %s, ignoring", key) + c.logger.Printf("No connection found for %s, ignoring", key) c.wakeupForTesting() return } @@ -917,7 +922,7 @@ func (c *GrpcClients) removeEtcdClientLocked(key string) { } for _, client := range entry.clients { - log.Printf("Removing connection to %s (from %s)", client.Target(), key) + c.logger.Printf("Removing connection to %s (from %s)", client.Target(), key) c.closeClient(client) } delete(c.clientsMap, info.Address) @@ -951,7 +956,7 @@ func (c *GrpcClients) wakeupForTesting() { func (c *GrpcClients) Reload(config *goconf.ConfigFile) { if err := c.load(config, true); err != nil { - log.Printf("Could not reload RPC clients: %s", err) + c.logger.Printf("Could not reload RPC clients: %s", err) } } @@ -962,7 +967,7 @@ func (c *GrpcClients) Close() { for _, entry := range c.clientsMap { for _, client := range entry.clients { if err := client.Close(); err != nil { - log.Printf("Error closing client to %s: %s", client.Target(), err) + c.logger.Printf("Error closing client to %s: %s", client.Target(), err) } } diff --git a/grpc_client_test.go b/grpc_client_test.go index 329cd37..031c1ba 100644 --- a/grpc_client_test.go +++ b/grpc_client_test.go @@ -53,7 +53,9 @@ func (c *GrpcClients) getWakeupChannelForTesting() <-chan struct{} { func NewGrpcClientsForTestWithConfig(t *testing.T, config *goconf.ConfigFile, etcdClient *EtcdClient) (*GrpcClients, *DnsMonitor) { dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually - client, err := NewGrpcClients(config, etcdClient, dnsMonitor, "0.0.0") + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + client, err := NewGrpcClients(ctx, config, etcdClient, dnsMonitor, "0.0.0") require.NoError(t, err) t.Cleanup(func() { client.Close() @@ -77,7 +79,8 @@ func NewGrpcClientsWithEtcdForTest(t *testing.T, etcd *embed.Etcd) (*GrpcClients config.AddOption("grpc", "targettype", "etcd") config.AddOption("grpc", "targetprefix", "/grpctargets") - etcdClient, err := NewEtcdClient(config, "") + logger := NewLoggerForTest(t) + etcdClient, err := NewEtcdClient(logger, config, "") require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, etcdClient.Close()) @@ -108,7 +111,8 @@ func waitForEvent(ctx context.Context, t *testing.T, ch <-chan struct{}) { } func Test_GrpcClients_EtcdInitial(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) ensureNoGoroutinesLeak(t, func(t *testing.T) { _, addr1 := NewGrpcServerForTest(t) _, addr2 := NewGrpcServerForTest(t) @@ -119,7 +123,7 @@ func Test_GrpcClients_EtcdInitial(t *testing.T) { SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) client, _ := NewGrpcClientsWithEtcdForTest(t, etcd) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() require.NoError(t, client.WaitForInitialized(ctx)) @@ -130,13 +134,14 @@ func Test_GrpcClients_EtcdInitial(t *testing.T) { func Test_GrpcClients_EtcdUpdate(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) etcd := NewEtcdForTest(t) client, _ := NewGrpcClientsWithEtcdForTest(t, etcd) ch := client.getWakeupChannelForTesting() - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() assert.Empty(client.GetClients()) @@ -176,13 +181,14 @@ func Test_GrpcClients_EtcdUpdate(t *testing.T) { func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) etcd := NewEtcdForTest(t) client, _ := NewGrpcClientsWithEtcdForTest(t, etcd) ch := client.getWakeupChannelForTesting() - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() assert.Empty(client.GetClients()) @@ -214,7 +220,8 @@ func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) { } func Test_GrpcClients_DnsDiscovery(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) ensureNoGoroutinesLeak(t, func(t *testing.T) { assert := assert.New(t) require := require.New(t) @@ -228,7 +235,7 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) { client, dnsMonitor := NewGrpcClientsForTest(t, target) ch := client.getWakeupChannelForTesting() - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() // Wait for initial check to be done to make sure internal dnsmonitor goroutine is waiting. @@ -268,7 +275,6 @@ func Test_GrpcClients_DnsDiscovery(t *testing.T) { } func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) { - CatchLogForTest(t) assert := assert.New(t) lookup := newMockDnsLookupForTest(t) target := "testgrpc:12345" @@ -298,7 +304,6 @@ func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) { } func Test_GrpcClients_Encryption(t *testing.T) { - CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { require := require.New(t) serverKey, err := rsa.GenerateKey(rand.Reader, 1024) diff --git a/grpc_common.go b/grpc_common.go index b7df93e..dfea756 100644 --- a/grpc_common.go +++ b/grpc_common.go @@ -25,7 +25,6 @@ import ( "context" "crypto/tls" "fmt" - "log" "net" "github.com/dlintw/goconf" @@ -134,7 +133,7 @@ func (c *reloadableCredentials) Close() { } } -func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentials.TransportCredentials, error) { +func NewReloadableCredentials(logger Logger, config *goconf.ConfigFile, server bool) (credentials.TransportCredentials, error) { var prefix string var caPrefix string if server { @@ -153,7 +152,7 @@ func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentia var loader *CertificateReloader var err error if certificateFile != "" && keyFile != "" { - loader, err = NewCertificateReloader(certificateFile, keyFile) + loader, err = NewCertificateReloader(logger, certificateFile, keyFile) if err != nil { return nil, fmt.Errorf("invalid GRPC %s certificate / key in %s / %s: %w", prefix, certificateFile, keyFile, err) } @@ -161,7 +160,7 @@ func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentia var pool *CertPoolReloader if caFile != "" { - pool, err = NewCertPoolReloader(caFile) + pool, err = NewCertPoolReloader(logger, caFile) if err != nil { return nil, err } @@ -173,9 +172,9 @@ func NewReloadableCredentials(config *goconf.ConfigFile, server bool) (credentia if loader == nil && pool == nil { if server { - log.Printf("WARNING: No GRPC server certificate and/or key configured, running unencrypted") + logger.Printf("WARNING: No GRPC server certificate and/or key configured, running unencrypted") } else { - log.Printf("WARNING: No GRPC CA configured, expecting unencrypted connections") + logger.Printf("WARNING: No GRPC CA configured, expecting unencrypted connections") } return insecure.NewCredentials(), nil } diff --git a/grpc_remote_client.go b/grpc_remote_client.go index 8940fde..02133ec 100644 --- a/grpc_remote_client.go +++ b/grpc_remote_client.go @@ -27,7 +27,6 @@ import ( "errors" "fmt" "io" - "log" "sync/atomic" "google.golang.org/grpc/codes" @@ -49,6 +48,7 @@ func getMD(md metadata.MD, key string) string { // remoteGrpcClient is a remote client connecting from a GRPC proxy to a Hub. type remoteGrpcClient struct { + logger Logger hub *Hub client RpcSessions_ProxySessionServer @@ -73,6 +73,7 @@ func newRemoteGrpcClient(hub *Hub, request RpcSessions_ProxySessionServer) (*rem closeCtx, closeFunc := context.WithCancelCause(context.Background()) result := &remoteGrpcClient{ + logger: hub.logger, hub: hub, client: request, @@ -105,7 +106,7 @@ func (c *remoteGrpcClient) readPump() { } if status.Code(err) != codes.Canceled { - log.Printf("Error reading from remote client for session %s: %s", c.sessionId, err) + c.logger.Printf("Error reading from remote client for session %s: %s", c.sessionId, err) closeError = err } break @@ -193,7 +194,7 @@ func (c *remoteGrpcClient) SendMessage(message WritableClientMessage) bool { case c.messages <- message: return true default: - log.Printf("Message queue for remote client of session %s is full, not sending %+v", c.sessionId, message) + c.logger.Printf("Message queue for remote client of session %s is full, not sending %+v", c.sessionId, message) return false } } @@ -215,7 +216,7 @@ func (c *remoteGrpcClient) run() error { case msg := <-c.messages: data, err := json.Marshal(msg) if err != nil { - log.Printf("Error marshalling %+v for remote client for session %s: %s", msg, c.sessionId, err) + c.logger.Printf("Error marshalling %+v for remote client for session %s: %s", msg, c.sessionId, err) continue } diff --git a/grpc_server.go b/grpc_server.go index 9639e16..c7863c4 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -27,7 +27,6 @@ import ( "encoding/hex" "errors" "fmt" - "log" "net" "net/url" "os" @@ -73,6 +72,7 @@ type GrpcServer struct { UnimplementedRpcMcuServer UnimplementedRpcSessionsServer + logger Logger version string creds credentials.TransportCredentials conn *grpc.Server @@ -82,7 +82,7 @@ type GrpcServer struct { hub GrpcServerHub } -func NewGrpcServer(config *goconf.ConfigFile, version string) (*GrpcServer, error) { +func NewGrpcServer(ctx context.Context, config *goconf.ConfigFile, version string) (*GrpcServer, error) { var listener net.Listener if addr, _ := GetStringOptionWithEnv(config, "grpc", "listen"); addr != "" { var err error @@ -92,13 +92,15 @@ func NewGrpcServer(config *goconf.ConfigFile, version string) (*GrpcServer, erro } } - creds, err := NewReloadableCredentials(config, true) + logger := LoggerFromContext(ctx) + creds, err := NewReloadableCredentials(logger, config, true) if err != nil { return nil, err } conn := grpc.NewServer(grpc.Creds(creds)) result := &GrpcServer{ + logger: logger, version: version, creds: creds, conn: conn, @@ -130,7 +132,7 @@ func (s *GrpcServer) Close() { func (s *GrpcServer) LookupResumeId(ctx context.Context, request *LookupResumeIdRequest) (*LookupResumeIdReply, error) { statsGrpcServerCalls.WithLabelValues("LookupResumeId").Inc() // TODO: Remove debug logging - log.Printf("Lookup session for resume id %s", request.ResumeId) + s.logger.Printf("Lookup session for resume id %s", request.ResumeId) session := s.hub.GetSessionByResumeId(PrivateSessionId(request.ResumeId)) if session == nil { return nil, status.Error(codes.NotFound, "no such room session id") @@ -144,7 +146,7 @@ func (s *GrpcServer) LookupResumeId(ctx context.Context, request *LookupResumeId func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) { statsGrpcServerCalls.WithLabelValues("LookupSessionId").Inc() // TODO: Remove debug logging - log.Printf("Lookup session id for room session id %s", request.RoomSessionId) + s.logger.Printf("Lookup session id for room session id %s", request.RoomSessionId) sid, err := s.hub.GetSessionIdByRoomSessionId(RoomSessionId(request.RoomSessionId)) if errors.Is(err, ErrNoSuchRoomSession) { return nil, status.Error(codes.NotFound, "no such room session id") @@ -154,7 +156,7 @@ func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSession if sid != "" && request.DisconnectReason != "" { if session := s.hub.GetSessionByPublicId(PublicSessionId(sid)); session != nil { - log.Printf("Closing session %s because same room session %s connected", session.PublicId(), request.RoomSessionId) + s.logger.Printf("Closing session %s because same room session %s connected", session.PublicId(), request.RoomSessionId) session.LeaveRoom(false) switch sess := session.(type) { case *ClientSession: @@ -173,7 +175,7 @@ func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSession func (s *GrpcServer) IsSessionInCall(ctx context.Context, request *IsSessionInCallRequest) (*IsSessionInCallReply, error) { statsGrpcServerCalls.WithLabelValues("IsSessionInCall").Inc() // TODO: Remove debug logging - log.Printf("Check if session %s is in call %s on %s", request.SessionId, request.RoomId, request.BackendUrl) + s.logger.Printf("Check if session %s is in call %s on %s", request.SessionId, request.RoomId, request.BackendUrl) session := s.hub.GetSessionByPublicId(PublicSessionId(request.SessionId)) if session == nil { return nil, status.Error(codes.NotFound, "no such session id") @@ -194,7 +196,7 @@ func (s *GrpcServer) IsSessionInCall(ctx context.Context, request *IsSessionInCa func (s *GrpcServer) GetInternalSessions(ctx context.Context, request *GetInternalSessionsRequest) (*GetInternalSessionsReply, error) { statsGrpcServerCalls.WithLabelValues("GetInternalSessions").Inc() // TODO: Remove debug logging - log.Printf("Get internal sessions from %s on %v (fallback %s)", request.RoomId, request.BackendUrls, request.BackendUrl) + s.logger.Printf("Get internal sessions from %s on %v (fallback %s)", request.RoomId, request.BackendUrls, request.BackendUrl) var backendUrls []string if len(request.BackendUrls) > 0 { @@ -259,7 +261,7 @@ func (s *GrpcServer) GetInternalSessions(ctx context.Context, request *GetIntern func (s *GrpcServer) GetPublisherId(ctx context.Context, request *GetPublisherIdRequest) (*GetPublisherIdReply, error) { statsGrpcServerCalls.WithLabelValues("GetPublisherId").Inc() // TODO: Remove debug logging - log.Printf("Get %s publisher id for session %s", request.StreamType, request.SessionId) + s.logger.Printf("Get %s publisher id for session %s", request.StreamType, request.SessionId) session := s.hub.GetSessionByPublicId(PublicSessionId(request.SessionId)) if session == nil { return nil, status.Error(codes.NotFound, "no such session") @@ -281,11 +283,11 @@ func (s *GrpcServer) GetPublisherId(ctx context.Context, request *GetPublisherId } var err error if reply.ConnectToken, err = s.hub.CreateProxyToken(""); err != nil && !errors.Is(err, ErrNoProxyMcu) { - log.Printf("Error creating proxy token for connection: %s", err) + s.logger.Printf("Error creating proxy token for connection: %s", err) return nil, status.Error(codes.Internal, "error creating proxy connect token") } if reply.PublisherToken, err = s.hub.CreateProxyToken(publisher.Id()); err != nil && !errors.Is(err, ErrNoProxyMcu) { - log.Printf("Error creating proxy token for publisher %s: %s", publisher.Id(), err) + s.logger.Printf("Error creating proxy token for publisher %s: %s", publisher.Id(), err) return nil, status.Error(codes.Internal, "error creating proxy publisher token") } return reply, nil diff --git a/grpc_server_test.go b/grpc_server_test.go index 8ebdc7c..87a036a 100644 --- a/grpc_server_test.go +++ b/grpc_server_test.go @@ -62,11 +62,13 @@ func (s *GrpcServer) WaitForCertPoolReload(ctx context.Context, counter uint64) } func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) { + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) for port := 50000; port < 50100; port++ { addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) config.AddOption("grpc", "listen", addr) var err error - server, err = NewGrpcServer(config, "0.0.0") + server, err = NewGrpcServer(ctx, config, "0.0.0") if isErrorAddressAlreadyInUse(err) { continue } @@ -96,7 +98,6 @@ func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) { } func Test_GrpcServer_ReloadCerts(t *testing.T) { - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) key, err := rsa.GenerateKey(rand.Reader, 1024) @@ -167,7 +168,7 @@ func Test_GrpcServer_ReloadCerts(t *testing.T) { } func Test_GrpcServer_ReloadCA(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) require := require.New(t) serverKey, err := rsa.GenerateKey(rand.Reader, 1024) require.NoError(err) @@ -211,7 +212,7 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { RootCAs: pool, Certificates: []tls.Certificate{pair1}, } - client1, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg1))) + client1, err := NewGrpcClient(logger, addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg1))) require.NoError(err) defer client1.Close() // nolint @@ -237,7 +238,7 @@ func Test_GrpcServer_ReloadCA(t *testing.T) { RootCAs: pool, Certificates: []tls.Certificate{pair2}, } - client2, err := NewGrpcClient(addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg2))) + client2, err := NewGrpcClient(logger, addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg2))) require.NoError(err) defer client2.Close() // nolint diff --git a/hub.go b/hub.go index b70eded..75a683a 100644 --- a/hub.go +++ b/hub.go @@ -36,7 +36,6 @@ import ( "errors" "fmt" "hash/fnv" - "log" "net" "net/http" "net/url" @@ -140,6 +139,7 @@ func init() { type Hub struct { version string + logger Logger events AsyncEvents upgrader websocket.Upgrader cookie *SessionIdCodec @@ -219,13 +219,14 @@ type Hub struct { blockedCandidates atomic.Pointer[AllowedIps] } -func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer, rpcClients *GrpcClients, etcdClient *EtcdClient, r *mux.Router, version string) (*Hub, error) { +func NewHub(ctx context.Context, config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer, rpcClients *GrpcClients, etcdClient *EtcdClient, r *mux.Router, version string) (*Hub, error) { + logger := LoggerFromContext(ctx) hashKey, _ := GetStringOptionWithEnv(config, "sessions", "hashkey") switch len(hashKey) { case 32: case 64: default: - log.Printf("WARNING: The sessions hash key should be 32 or 64 bytes but is %d bytes", len(hashKey)) + logger.Printf("WARNING: The sessions hash key should be 32 or 64 bytes but is %d bytes", len(hashKey)) } blockKey, _ := GetStringOptionWithEnv(config, "sessions", "blockkey") @@ -242,7 +243,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer internalClientsSecret, _ := GetStringOptionWithEnv(config, "clients", "internalsecret") if internalClientsSecret == "" { - log.Println("WARNING: No shared secret has been set for internal clients.") + logger.Println("WARNING: No shared secret has been set for internal clients.") } maxConcurrentRequestsPerHost, _ := config.GetInt("backend", "connectionsperhost") @@ -250,18 +251,18 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer maxConcurrentRequestsPerHost = defaultMaxConcurrentRequestsPerHost } - backend, err := NewBackendClient(config, maxConcurrentRequestsPerHost, version, etcdClient) + backend, err := NewBackendClient(ctx, config, maxConcurrentRequestsPerHost, version, etcdClient) if err != nil { return nil, err } - log.Printf("Using a maximum of %d concurrent backend connections per host", maxConcurrentRequestsPerHost) + logger.Printf("Using a maximum of %d concurrent backend connections per host", maxConcurrentRequestsPerHost) backendTimeoutSeconds, _ := config.GetInt("backend", "timeout") if backendTimeoutSeconds <= 0 { backendTimeoutSeconds = defaultBackendTimeoutSeconds } backendTimeout := time.Duration(backendTimeoutSeconds) * time.Second - log.Printf("Using a timeout of %s for backend connections", backendTimeout) + logger.Printf("Using a timeout of %s for backend connections", backendTimeout) mcuTimeoutSeconds, _ := config.GetInt("mcu", "timeout") if mcuTimeoutSeconds <= 0 { @@ -271,7 +272,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer allowSubscribeAnyStream, _ := config.GetBool("app", "allowsubscribeany") if allowSubscribeAnyStream { - log.Printf("WARNING: Allow subscribing any streams, this is insecure and should only be enabled for testing") + logger.Printf("WARNING: Allow subscribing any streams, this is insecure and should only be enabled for testing") } trustedProxies, _ := config.GetString("app", "trustedproxies") @@ -282,7 +283,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer skipFederationVerify, _ := config.GetBool("federation", "skipverify") if skipFederationVerify { - log.Println("WARNING: Federation target verification is disabled!") + logger.Println("WARNING: Federation target verification is disabled!") } federationTimeoutSeconds, _ := config.GetInt("federation", "timeout") if federationTimeoutSeconds <= 0 { @@ -291,10 +292,10 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer federationTimeout := time.Duration(federationTimeoutSeconds) * time.Second if !trustedProxiesIps.Empty() { - log.Printf("Trusted proxies: %s", trustedProxiesIps) + logger.Printf("Trusted proxies: %s", trustedProxiesIps) } else { trustedProxiesIps = DefaultTrustedProxies - log.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) + logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) } decodeCaches := make([]*LruCache[*SessionIdData], 0, numDecodeCaches) @@ -325,20 +326,20 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer var geoip *GeoLookup if geoipUrl != "" { if geoipUrl, found := strings.CutPrefix(geoipUrl, "file://"); found { - log.Printf("Using GeoIP database from %s", geoipUrl) - geoip, err = NewGeoLookupFromFile(geoipUrl) + logger.Printf("Using GeoIP database from %s", geoipUrl) + geoip, err = NewGeoLookupFromFile(logger, geoipUrl) } else { - log.Printf("Downloading GeoIP database from %s", geoipUrl) - geoip, err = NewGeoLookupFromUrl(geoipUrl) + logger.Printf("Downloading GeoIP database from %s", geoipUrl) + geoip, err = NewGeoLookupFromUrl(logger, geoipUrl) } if err != nil { return nil, err } } else { - log.Printf("Not using GeoIP database") + logger.Printf("Not using GeoIP database") } - geoipOverrides, err := LoadGeoIPOverrides(config, false) + geoipOverrides, err := LoadGeoIPOverrides(ctx, config, false) if err != nil { return nil, err } @@ -350,6 +351,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer hub := &Hub{ version: version, + logger: logger, events: events, upgrader: websocket.Upgrader{ ReadBufferSize: websocketReadBufferSize, @@ -414,10 +416,10 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer return nil, fmt.Errorf("invalid allowedcandidates: %w", err) } - log.Printf("Candidates allowlist: %s", allowed) + logger.Printf("Candidates allowlist: %s", allowed) hub.allowedCandidates.Store(allowed) } else { - log.Printf("No candidates allowlist") + logger.Printf("No candidates allowlist") } if value, _ := config.GetString("mcu", "blockedcandidates"); value != "" { blocked, err := ParseAllowedIps(value) @@ -425,10 +427,10 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer return nil, fmt.Errorf("invalid blockedcandidates: %w", err) } - log.Printf("Candidates blocklist: %s", blocked) + logger.Printf("Candidates blocklist: %s", blocked) hub.blockedCandidates.Store(blocked) } else { - log.Printf("No candidates blocklist") + logger.Printf("No candidates blocklist") } hub.trustedProxies.Store(trustedProxiesIps) @@ -469,7 +471,7 @@ func (h *Hub) SetMcu(mcu Mcu) { welcome.Welcome.RemoveFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp) } else { - log.Printf("Using a timeout of %s for MCU requests", h.mcuTimeout) + h.logger.Printf("Using a timeout of %s for MCU requests", h.mcuTimeout) h.info.AddFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp) h.infoInternal.AddFeature(ServerFeatureMcu, ServerFeatureSimulcast, ServerFeatureUpdateSdp) @@ -504,7 +506,7 @@ func (h *Hub) updateGeoDatabase() { defer h.geoipUpdating.Store(false) backoff, err := NewExponentialBackoff(time.Second, 5*time.Minute) if err != nil { - log.Printf("Could not create exponential backoff: %s", err) + h.logger.Printf("Could not create exponential backoff: %s", err) return } @@ -514,7 +516,7 @@ func (h *Hub) updateGeoDatabase() { break } - log.Printf("Could not update GeoIP database, will retry in %s (%s)", backoff.NextWait(), err) + h.logger.Printf("Could not update GeoIP database, will retry in %s (%s)", backoff.NextWait(), err) backoff.Wait(context.Background()) } } @@ -562,21 +564,21 @@ func (h *Hub) Stop() { h.throttler.Close() } -func (h *Hub) Reload(config *goconf.ConfigFile) { +func (h *Hub) Reload(ctx context.Context, config *goconf.ConfigFile) { trustedProxies, _ := config.GetString("app", "trustedproxies") if trustedProxiesIps, err := ParseAllowedIps(trustedProxies); err == nil { if !trustedProxiesIps.Empty() { - log.Printf("Trusted proxies: %s", trustedProxiesIps) + h.logger.Printf("Trusted proxies: %s", trustedProxiesIps) } else { trustedProxiesIps = DefaultTrustedProxies - log.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) + h.logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) } h.trustedProxies.Store(trustedProxiesIps) } else { - log.Printf("Error parsing trusted proxies from \"%s\": %s", trustedProxies, err) + h.logger.Printf("Error parsing trusted proxies from \"%s\": %s", trustedProxies, err) } - geoipOverrides, _ := LoadGeoIPOverrides(config, true) + geoipOverrides, _ := LoadGeoIPOverrides(ctx, config, true) if len(geoipOverrides) > 0 { h.geoipOverrides.Store(&geoipOverrides) } else { @@ -585,24 +587,24 @@ func (h *Hub) Reload(config *goconf.ConfigFile) { if value, _ := config.GetString("mcu", "allowedcandidates"); value != "" { if allowed, err := ParseAllowedIps(value); err != nil { - log.Printf("invalid allowedcandidates: %s", err) + h.logger.Printf("invalid allowedcandidates: %s", err) } else { - log.Printf("Candidates allowlist: %s", allowed) + h.logger.Printf("Candidates allowlist: %s", allowed) h.allowedCandidates.Store(allowed) } } else { - log.Printf("No candidates allowlist") + h.logger.Printf("No candidates allowlist") h.allowedCandidates.Store(nil) } if value, _ := config.GetString("mcu", "blockedcandidates"); value != "" { if blocked, err := ParseAllowedIps(value); err != nil { - log.Printf("invalid blockedcandidates: %s", err) + h.logger.Printf("invalid blockedcandidates: %s", err) } else { - log.Printf("Candidates blocklist: %s", blocked) + h.logger.Printf("Candidates blocklist: %s", blocked) h.blockedCandidates.Store(blocked) } } else { - log.Printf("No candidates blocklist") + h.logger.Printf("No candidates blocklist") h.blockedCandidates.Store(nil) } @@ -769,7 +771,7 @@ func (h *Hub) checkExpiredSessions(now time.Time) { for session, expires := range h.expiredSessions { if now.After(expires) { h.mu.Unlock() - log.Printf("Closing expired session %s (private=%s)", session.PublicId(), session.PrivateId()) + h.logger.Printf("Closing expired session %s (private=%s)", session.PublicId(), session.PrivateId()) session.Close() h.mu.Lock() // Should already be deleted by the close code, but better be sure. @@ -961,7 +963,7 @@ func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend * client, ok := c.(*Client) if !ok { - log.Printf("Can't register non-client %T", c) + h.logger.Printf("Can't register non-client %T", c) client.SendMessage(message.NewWrappedErrorServerMessage(errors.New("can't register non-client"))) return } @@ -980,11 +982,11 @@ func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend * userId := auth.Auth.UserId if userId != "" { - log.Printf("Register user %s@%s from %s in %s (%s) %s (private=%s)", userId, backend.Id(), client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) + h.logger.Printf("Register user %s@%s from %s in %s (%s) %s (private=%s)", userId, backend.Id(), client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) } else if message.Hello.Auth.Type != HelloClientTypeClient { - log.Printf("Register %s@%s from %s in %s (%s) %s (private=%s)", message.Hello.Auth.Type, backend.Id(), client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) + h.logger.Printf("Register %s@%s from %s in %s (%s) %s (private=%s)", message.Hello.Auth.Type, backend.Id(), client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) } else { - log.Printf("Register anonymous@%s from %s in %s (%s) %s (private=%s)", backend.Id(), client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) + h.logger.Printf("Register anonymous@%s from %s in %s (%s) %s (private=%s)", backend.Id(), client.RemoteAddr(), client.Country(), client.UserAgent(), publicSessionId, privateSessionId) } session, err := NewClientSession(h, privateSessionId, publicSessionId, sessionIdData, backend, message.Hello, auth.Auth) @@ -994,7 +996,7 @@ func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend * } if err := backend.AddSession(session); err != nil { - log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), err) + h.logger.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), err) session.Close() client.SendMessage(message.NewWrappedErrorServerMessage(err)) return @@ -1013,12 +1015,12 @@ func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend * count, err := c.GetSessionCount(ctx, session.BackendUrl()) if err != nil { - log.Printf("Received error while getting session count for %s from %s: %s", session.BackendUrl(), c.Target(), err) + h.logger.Printf("Received error while getting session count for %s from %s: %s", session.BackendUrl(), c.Target(), err) return } if count > 0 { - log.Printf("%d sessions connected for %s on %s", count, session.BackendUrl(), c.Target()) + h.logger.Printf("%d sessions connected for %s on %s", count, session.BackendUrl(), c.Target()) totalCount.Add(count) } }(client) @@ -1026,7 +1028,7 @@ func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend * wg.Wait() if totalCount.Load() > limit { backend.RemoveSession(session) - log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), SessionLimitExceeded) + h.logger.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), SessionLimitExceeded) session.Close() client.SendMessage(message.NewWrappedErrorServerMessage(SessionLimitExceeded)) return @@ -1078,7 +1080,7 @@ func (h *Hub) processUnregister(client HandlerClient) Session { } h.mu.Unlock() if session != nil { - log.Printf("Unregister %s (private=%s)", session.PublicId(), session.PrivateId()) + h.logger.Printf("Unregister %s (private=%s)", session.PublicId(), session.PrivateId()) if c, ok := client.(*Client); ok { if cs, ok := session.(*ClientSession); ok { cs.ClearClient(c) @@ -1094,10 +1096,10 @@ func (h *Hub) processMessage(client HandlerClient, data []byte) { var message ClientMessage if err := message.UnmarshalJSON(data); err != nil { if session := client.GetSession(); session != nil { - log.Printf("Error decoding message from client %s: %v", session.PublicId(), err) + h.logger.Printf("Error decoding message from client %s: %v", session.PublicId(), err) session.SendError(InvalidFormat) } else { - log.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err) + h.logger.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err) client.SendError(InvalidFormat) } return @@ -1105,14 +1107,14 @@ func (h *Hub) processMessage(client HandlerClient, data []byte) { if err := message.CheckValid(); err != nil { if session := client.GetSession(); session != nil { - log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + h.logger.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) if err, ok := err.(*Error); ok { session.SendMessage(message.NewErrorServerMessage(err)) } else { session.SendMessage(message.NewErrorServerMessage(InvalidFormat)) } } else { - log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) + h.logger.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) if err, ok := err.(*Error); ok { client.SendMessage(message.NewErrorServerMessage(err)) } else { @@ -1161,9 +1163,9 @@ func (h *Hub) processMessage(client HandlerClient, data []byte) { case "bye": h.processByeMsg(client, &message) case "hello": - log.Printf("Ignore hello %+v for already authenticated connection %s", message.Hello, session.PublicId()) + h.logger.Printf("Ignore hello %+v for already authenticated connection %s", message.Hello, session.PublicId()) default: - log.Printf("Ignore unknown message %+v from %s", message, session.PublicId()) + h.logger.Printf("Ignore unknown message %+v from %s", message, session.PublicId()) } } @@ -1220,7 +1222,7 @@ func (h *Hub) tryProxyResume(c HandlerClient, resumeId PrivateSessionId, message response, err := client.LookupResumeId(ctx, resumeId) if err != nil { - log.Printf("Could not lookup resume id %s on %s: %s", resumeId, client.Target(), err) + h.logger.Printf("Could not lookup resume id %s on %s: %s", resumeId, client.Target(), err) return } @@ -1245,17 +1247,17 @@ func (h *Hub) tryProxyResume(c HandlerClient, resumeId PrivateSessionId, message rs, err := NewRemoteSession(h, client, info.client, PublicSessionId(info.response.SessionId)) if err != nil { - log.Printf("Could not create remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err) + h.logger.Printf("Could not create remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err) return false } if err := rs.Start(message); err != nil { rs.Close() - log.Printf("Could not start remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err) + h.logger.Printf("Could not start remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err) return false } - log.Printf("Proxy session %s to %s", info.response.SessionId, info.client.Target()) + h.logger.Printf("Proxy session %s to %s", info.response.SessionId, info.client.Target()) h.mu.Lock() defer h.mu.Unlock() h.remoteSessions[rs] = true @@ -1264,7 +1266,7 @@ func (h *Hub) tryProxyResume(c HandlerClient, resumeId PrivateSessionId, message } func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { - ctx := context.TODO() + ctx := NewLoggerContext(client.Context(), h.logger) resumeId := message.Hello.ResumeId if resumeId != "" { throttle, err := h.throttler.CheckBruteforce(ctx, client.RemoteAddr(), "HelloResume") @@ -1272,7 +1274,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { client.SendMessage(message.NewErrorServerMessage(TooManyRequests)) return } else if err != nil { - log.Printf("Error checking for bruteforce: %s", err) + h.logger.Printf("Error checking for bruteforce: %s", err) client.SendMessage(message.NewWrappedErrorServerMessage(err)) return } @@ -1307,7 +1309,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { if !ok { // Should never happen as clients only can resume their own sessions. h.mu.Unlock() - log.Printf("Client resumed non-client session %s (private=%s)", session.PublicId(), session.PrivateId()) + h.logger.Printf("Client resumed non-client session %s (private=%s)", session.PublicId(), session.PrivateId()) statsHubSessionResumeFailed.Inc() client.SendMessage(message.NewErrorServerMessage(NoSuchSession)) return @@ -1320,7 +1322,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { } if prev := clientSession.SetClient(client); prev != nil { - log.Printf("Closing previous client from %s for session %s", prev.RemoteAddr(), session.PublicId()) + h.logger.Printf("Closing previous client from %s for session %s", prev.RemoteAddr(), session.PublicId()) prev.SendByeResponseWithReason(nil, "session_resumed") } @@ -1329,7 +1331,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { delete(h.expectHelloClients, client) h.mu.Unlock() - log.Printf("Resume session from %s in %s (%s) %s (private=%s)", client.RemoteAddr(), client.Country(), client.UserAgent(), session.PublicId(), session.PrivateId()) + h.logger.Printf("Resume session from %s in %s (%s) %s (private=%s)", client.RemoteAddr(), client.Country(), client.UserAgent(), session.PublicId(), session.PrivateId()) statsHubSessionsResumedTotal.WithLabelValues(clientSession.Backend().Id(), string(clientSession.ClientType())).Inc() h.sendHelloResponse(clientSession, message) @@ -1437,7 +1439,7 @@ func (h *Hub) processHelloV2(ctx context.Context, client HandlerClient, message return jwt.ParseEdPublicKeyFromPEM(data) } default: - log.Printf("Unexpected signing method: %v", token.Header["alg"]) + h.logger.Printf("Unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } @@ -1570,13 +1572,13 @@ func (h *Hub) processHelloInternal(client HandlerClient, message *ClientMessage) return } - ctx := context.TODO() + ctx := NewLoggerContext(client.Context(), h.logger) throttle, err := h.throttler.CheckBruteforce(ctx, client.RemoteAddr(), "HelloInternal") if err == ErrBruteforceDetected { client.SendMessage(message.NewErrorServerMessage(TooManyRequests)) return } else if err != nil { - log.Printf("Error checking for bruteforce: %s", err) + h.logger.Printf("Error checking for bruteforce: %s", err) client.SendMessage(message.NewWrappedErrorServerMessage(err)) return } @@ -1611,7 +1613,7 @@ func (h *Hub) disconnectByRoomSessionId(ctx context.Context, roomSessionId RoomS if err == ErrNoSuchRoomSession { return } else if err != nil { - log.Printf("Could not get session id for room session %s: %s", roomSessionId, err) + h.logger.Printf("Could not get session id for room session %s: %s", roomSessionId, err) return } @@ -1629,12 +1631,12 @@ func (h *Hub) disconnectByRoomSessionId(ctx context.Context, roomSessionId RoomS }, } if err := h.events.PublishSessionMessage(sessionId, backend, msg); err != nil { - log.Printf("Could not send reconnect bye to session %s: %s", sessionId, err) + h.logger.Printf("Could not send reconnect bye to session %s: %s", sessionId, err) } return } - log.Printf("Closing session %s because same room session %s connected", session.PublicId(), roomSessionId) + h.logger.Printf("Closing session %s because same room session %s connected", session.PublicId(), roomSessionId) session.LeaveRoom(false) switch sess := session.(type) { case *ClientSession: @@ -1787,7 +1789,7 @@ func (h *Hub) processRoom(sess Session, message *ClientMessage) { } } - log.Printf("Error creating federation client to %s for %s to join room %s: %s", federation.SignalingUrl, session.PublicId(), roomId, err) + h.logger.Printf("Error creating federation client to %s for %s to join room %s: %s", federation.SignalingUrl, session.PublicId(), roomId, err) session.SendMessage(message.NewErrorServerMessage( NewErrorDetail("federation_error", "Failed to create federation client.", details), )) @@ -1799,14 +1801,14 @@ func (h *Hub) processRoom(sess Session, message *ClientMessage) { roomSessionId := message.Room.SessionId if roomSessionId == "" { // TODO(jojo): Better make the session id required in the request. - log.Printf("User did not send a room session id, assuming session %s", session.PublicId()) + h.logger.Printf("User did not send a room session id, assuming session %s", session.PublicId()) roomSessionId = RoomSessionId(session.PublicId()) } // Prefix room session id to allow using the same signaling server for two Nextcloud instances during development. // Otherwise the same room session id will be detected and the other session will be kicked. if err := session.UpdateRoomSessionId(FederatedRoomSessionIdPrefix + roomSessionId); err != nil { - log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err) + h.logger.Printf("Error updating room session id for session %s: %s", session.PublicId(), err) } h.mu.Lock() @@ -1821,12 +1823,12 @@ func (h *Hub) processRoom(sess Session, message *ClientMessage) { roomSessionId := message.Room.SessionId if roomSessionId == "" { // TODO(jojo): Better make the session id required in the request. - log.Printf("User did not send a room session id, assuming session %s", session.PublicId()) + h.logger.Printf("User did not send a room session id, assuming session %s", session.PublicId()) roomSessionId = RoomSessionId(session.PublicId()) } if err := session.UpdateRoomSessionId(roomSessionId); err != nil { - log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err) + h.logger.Printf("Error updating room session id for session %s: %s", session.PublicId(), err) } session.SendMessage(message.NewErrorServerMessage( NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{ @@ -1856,7 +1858,7 @@ func (h *Hub) processRoom(sess Session, message *ClientMessage) { sessionId := message.Room.SessionId if sessionId == "" { // TODO(jojo): Better make the session id required in the request. - log.Printf("User did not send a room session id, assuming session %s", session.PublicId()) + h.logger.Printf("User did not send a room session id, assuming session %s", session.PublicId()) sessionId = RoomSessionId(session.PublicId()) } request := NewBackendClientRoomRequest(roomId, session.UserId(), sessionId) @@ -1938,17 +1940,18 @@ func (h *Hub) publishFederatedSessions() (int, *sync.WaitGroup) { return 0, &wg } count := 0 + ctx := NewLoggerContext(context.Background(), h.logger) for roomId, entries := range rooms { for u, e := range entries { wg.Add(1) count += len(e) go func(roomId string, url *url.URL, entries []BackendPingEntry) { defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout) + sendCtx, cancel := context.WithTimeout(ctx, h.backendTimeout) defer cancel() - if err := h.roomPing.SendPings(ctx, roomId, url, entries); err != nil { - log.Printf("Error pinging room %s for active entries %+v: %s", roomId, entries, err) + if err := h.roomPing.SendPings(sendCtx, roomId, url, entries); err != nil { + h.logger.Printf("Error pinging room %s for active entries %+v: %s", roomId, entries, err) } }(roomId, urls[u], e) } @@ -2071,7 +2074,7 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { var data MessageClientMessageData if err := json.Unmarshal(msg.Data, &data); err == nil { if err := data.CheckValid(); err != nil { - log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + h.logger.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) if err, ok := err.(*Error); ok { session.SendMessage(message.NewErrorServerMessage(err)) } else { @@ -2119,7 +2122,7 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { return } - log.Printf("Closing screen publisher for %s", session.PublicId()) + h.logger.Printf("Closing screen publisher for %s", session.PublicId()) ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout) defer cancel() publisher.Close(ctx) @@ -2188,7 +2191,7 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { var data MessageClientMessageData if err := json.Unmarshal(msg.Data, &data); err == nil { if err := data.CheckValid(); err != nil { - log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + h.logger.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) if err, ok := err.(*Error); ok { session.SendMessage(message.NewErrorServerMessage(err)) } else { @@ -2204,7 +2207,7 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { } } if subject == "" { - log.Printf("Unknown recipient in message %+v from %s", msg, session.PublicId()) + h.logger.Printf("Unknown recipient in message %+v from %s", msg, session.PublicId()) return } @@ -2224,7 +2227,7 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { // The recipient is connected to this instance, no need to go through asynchronous events. if clientData != nil && clientData.Type == "sendoffer" { if err := session.IsAllowedToSend(clientData); err != nil { - log.Printf("Session %s is not allowed to send offer for %s, ignoring (%s)", session.PublicId(), clientData.RoomType, err) + h.logger.Printf("Session %s is not allowed to send offer for %s, ignoring (%s)", session.PublicId(), clientData.RoomType, err) sendNotAllowed(session, message, "Not allowed to send offer") return } @@ -2238,18 +2241,18 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { mc, err := recipient.GetOrCreateSubscriber(ctx, h.mcu, session.PublicId(), StreamType(clientData.RoomType)) if err != nil { - log.Printf("Could not create MCU subscriber for session %s to send %+v to %s: %s", session.PublicId(), clientData, recipient.PublicId(), err) + h.logger.Printf("Could not create MCU subscriber for session %s to send %+v to %s: %s", session.PublicId(), clientData, recipient.PublicId(), err) sendMcuClientNotFound(session, message) return } else if mc == nil { - log.Printf("No MCU subscriber found for session %s to send %+v to %s", session.PublicId(), clientData, recipient.PublicId()) + h.logger.Printf("No MCU subscriber found for session %s to send %+v to %s", session.PublicId(), clientData, recipient.PublicId()) sendMcuClientNotFound(session, message) return } mc.SendMessage(session.Context(), msg, clientData, func(err error, response api.StringMap) { if err != nil { - log.Printf("Could not send MCU message %+v for session %s to %s: %s", clientData, session.PublicId(), recipient.PublicId(), err) + h.logger.Printf("Could not send MCU message %+v for session %s to %s: %s", clientData, session.PublicId(), recipient.PublicId(), err) sendMcuProcessingFailed(session, message) return } else if response == nil { @@ -2270,7 +2273,7 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { } else { if clientData != nil && clientData.Type == "sendoffer" { if err := session.IsAllowedToSend(clientData); err != nil { - log.Printf("Session %s is not allowed to send offer for %s, ignoring (%s)", session.PublicId(), clientData.RoomType, err) + h.logger.Printf("Session %s is not allowed to send offer for %s, ignoring (%s)", session.PublicId(), clientData.RoomType, err) sendNotAllowed(session, message, "Not allowed to send offer") return } @@ -2284,7 +2287,7 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { }, } if err := h.events.PublishSessionMessage(recipientSessionId, session.Backend(), async); err != nil { - log.Printf("Error publishing message to remote session: %s", err) + h.logger.Printf("Error publishing message to remote session: %s", err) } return } @@ -2308,7 +2311,7 @@ func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) { } if err != nil { - log.Printf("Error publishing message to remote session: %s", err) + h.logger.Printf("Error publishing message to remote session: %s", err) } } } @@ -2330,7 +2333,7 @@ func isAllowedToControl(session Session) bool { func (h *Hub) processControlMsg(session Session, message *ClientMessage) { msg := message.Control if !isAllowedToControl(session) { - log.Printf("Ignore control message %+v from %s", msg, session.PublicId()) + h.logger.Printf("Ignore control message %+v from %s", msg, session.PublicId()) return } @@ -2398,7 +2401,7 @@ func (h *Hub) processControlMsg(session Session, message *ClientMessage) { } } if subject == "" { - log.Printf("Unknown recipient in message %+v from %s", msg, session.PublicId()) + h.logger.Printf("Unknown recipient in message %+v from %s", msg, session.PublicId()) return } @@ -2435,7 +2438,7 @@ func (h *Hub) processControlMsg(session Session, message *ClientMessage) { err = fmt.Errorf("unsupported recipient type: %s", msg.Recipient.Type) } if err != nil { - log.Printf("Error publishing message to remote session: %s", err) + h.logger.Printf("Error publishing message to remote session: %s", err) } } } @@ -2447,7 +2450,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { // Client is not connected yet. return } else if session.ClientType() != HelloClientTypeInternal { - log.Printf("Ignore internal message %+v from %s", msg, session.PublicId()) + h.logger.Printf("Ignore internal message %+v from %s", msg, session.PublicId()) return } @@ -2460,19 +2463,19 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { msg := msg.AddSession room := h.GetRoomForBackend(msg.RoomId, session.Backend()) if room == nil { - log.Printf("Ignore add session message %+v for invalid room %s from %s", *msg, msg.RoomId, session.PublicId()) + h.logger.Printf("Ignore add session message %+v for invalid room %s from %s", *msg, msg.RoomId, session.PublicId()) return } sessionIdData := h.newSessionIdData(session.Backend()) privateSessionId, err := h.cookie.EncodePrivate(sessionIdData) if err != nil { - log.Printf("Could not encode private virtual session id: %s", err) + h.logger.Printf("Could not encode private virtual session id: %s", err) return } publicSessionId, err := h.cookie.EncodePublic(sessionIdData) if err != nil { - log.Printf("Could not encode public virtual session id: %s", err) + h.logger.Printf("Could not encode public virtual session id: %s", err) return } @@ -2483,7 +2486,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { sess, err := NewVirtualSession(session, privateSessionId, publicSessionId, sessionIdData, msg) if err != nil { - log.Printf("Could not create virtual session %s: %s", virtualSessionId, err) + h.logger.Printf("Could not create virtual session %s: %s", virtualSessionId, err) reply := message.NewErrorServerMessage(NewError("add_failed", "Could not create virtual session.")) session.SendMessage(reply) return @@ -2498,7 +2501,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { var response BackendClientResponse if err := h.backend.PerformJSONRequest(ctx, session.ParsedBackendOcsUrl(), request, &response); err != nil { sess.Close() - log.Printf("Could not join virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err) + h.logger.Printf("Could not join virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err) reply := message.NewErrorServerMessage(NewError("add_failed", "Could not join virtual session.")) session.SendMessage(reply) return @@ -2506,7 +2509,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { if response.Type == "error" { sess.Close() - log.Printf("Could not join virtual session %s at backend %s: %+v", virtualSessionId, session.BackendUrl(), response.Error) + h.logger.Printf("Could not join virtual session %s at backend %s: %+v", virtualSessionId, session.BackendUrl(), response.Error) reply := message.NewErrorServerMessage(NewError("add_failed", response.Error.Error())) session.SendMessage(reply) return @@ -2516,7 +2519,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { var response BackendClientSessionResponse if err := h.backend.PerformJSONRequest(ctx, session.ParsedBackendOcsUrl(), request, &response); err != nil { sess.Close() - log.Printf("Could not add virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err) + h.logger.Printf("Could not add virtual session %s at backend %s: %s", virtualSessionId, session.BackendUrl(), err) reply := message.NewErrorServerMessage(NewError("add_failed", "Could not add virtual session.")) session.SendMessage(reply) return @@ -2529,7 +2532,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { h.mu.Unlock() statsHubSessionsCurrent.WithLabelValues(session.Backend().Id(), string(sess.ClientType())).Inc() statsHubSessionsTotal.WithLabelValues(session.Backend().Id(), string(sess.ClientType())).Inc() - log.Printf("Session %s added virtual session %s with initial flags %d", session.PublicId(), sess.PublicId(), sess.Flags()) + h.logger.Printf("Session %s added virtual session %s with initial flags %d", session.PublicId(), sess.PublicId(), sess.Flags()) session.AddVirtualSession(sess) sess.SetRoom(room) room.AddSession(sess, nil) @@ -2537,7 +2540,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { msg := msg.UpdateSession room := h.GetRoomForBackend(msg.RoomId, session.Backend()) if room == nil { - log.Printf("Ignore remove session message %+v for invalid room %s from %s", *msg, msg.RoomId, session.PublicId()) + h.logger.Printf("Ignore remove session message %+v for invalid room %s from %s", *msg, msg.RoomId, session.PublicId()) return } @@ -2565,7 +2568,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { } } } else { - log.Printf("Ignore update request for non-virtual session %s", sess.PublicId()) + h.logger.Printf("Ignore update request for non-virtual session %s", sess.PublicId()) } if changed != 0 { room.NotifySessionChanged(sess, changed) @@ -2575,7 +2578,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { msg := msg.RemoveSession room := h.GetRoomForBackend(msg.RoomId, session.Backend()) if room == nil { - log.Printf("Ignore remove session message %+v for invalid room %s from %s", *msg, msg.RoomId, session.PublicId()) + h.logger.Printf("Ignore remove session message %+v for invalid room %s from %s", *msg, msg.RoomId, session.PublicId()) return } @@ -2591,7 +2594,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { sess := h.sessions[sid] h.mu.Unlock() if sess != nil { - log.Printf("Session %s removed virtual session %s", session.PublicId(), sess.PublicId()) + h.logger.Printf("Session %s removed virtual session %s", session.PublicId(), sess.PublicId()) if vsess, ok := sess.(*VirtualSession); ok { // We should always have a VirtualSession here. vsess.CloseWithFeedback(session, message) @@ -2625,7 +2628,7 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { asyncMessage.Room.Transient.TTL = removeCallStatusTTL } if err := h.events.PublishBackendRoomMessage(roomId, session.Backend(), asyncMessage); err != nil { - log.Printf("Error publishing dialout message %+v to room %s", msg.Dialout, roomId) + h.logger.Printf("Error publishing dialout message %+v to room %s", msg.Dialout, roomId) } } else { if err := h.events.PublishRoomMessage(roomId, session.Backend(), &AsyncMessage{ @@ -2635,11 +2638,11 @@ func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) { Dialout: msg.Dialout, }, }); err != nil { - log.Printf("Error publishing dialout message %+v to room %s", msg.Dialout, roomId) + h.logger.Printf("Error publishing dialout message %+v to room %s", msg.Dialout, roomId) } } default: - log.Printf("Ignore unsupported internal message %+v from %s", msg, session.PublicId()) + h.logger.Printf("Ignore unsupported internal message %+v from %s", msg, session.PublicId()) return } } @@ -2725,7 +2728,7 @@ func (h *Hub) isInSameCallRemote(ctx context.Context, senderSession *ClientSessi if errors.Is(err, context.Canceled) { return } else if err != nil { - log.Printf("Error checking session %s in call on %s: %s", recipientSessionId, client.Target(), err) + h.logger.Printf("Error checking session %s in call on %s: %s", recipientSessionId, client.Target(), err) return } else if !inCall { return @@ -2778,14 +2781,14 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe switch data.Type { case "requestoffer": if session.PublicId() == message.Recipient.SessionId { - log.Printf("Not requesting offer from itself for session %s", session.PublicId()) + h.logger.Printf("Not requesting offer from itself for session %s", session.PublicId()) return } // A user is only allowed to subscribe a stream if she is in the same room // as the other user and both have their "inCall" flag set. if !h.allowSubscribeAnyStream && !h.isInSameCall(ctx, session, message.Recipient.SessionId) { - log.Printf("Session %s is not in the same call as session %s, not requesting offer", session.PublicId(), message.Recipient.SessionId) + h.logger.Printf("Session %s is not in the same call as session %s, not requesting offer", session.PublicId(), message.Recipient.SessionId) sendNotAllowed(session, client_message, "Not allowed to request offer.") return } @@ -2799,13 +2802,13 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe clientType = "publisher" mc, err = session.GetOrCreatePublisher(ctx, h.mcu, StreamType(data.RoomType), data) if err, ok := err.(*PermissionError); ok { - log.Printf("Session %s is not allowed to offer %s, ignoring (%s)", session.PublicId(), data.RoomType, err) + h.logger.Printf("Session %s is not allowed to offer %s, ignoring (%s)", session.PublicId(), data.RoomType, err) sendNotAllowed(session, client_message, "Not allowed to publish.") return } case "selectStream": if session.PublicId() == message.Recipient.SessionId { - log.Printf("Not selecting substream for own %s stream in session %s", data.RoomType, session.PublicId()) + h.logger.Printf("Not selecting substream for own %s stream in session %s", data.RoomType, session.PublicId()) return } @@ -2819,7 +2822,7 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe if session.PublicId() == message.Recipient.SessionId { if err := session.IsAllowedToSend(data); err != nil { - log.Printf("Session %s is not allowed to send candidate for %s, ignoring (%s)", session.PublicId(), data.RoomType, err) + h.logger.Printf("Session %s is not allowed to send candidate for %s, ignoring (%s)", session.PublicId(), data.RoomType, err) sendNotAllowed(session, client_message, "Not allowed to send candidate.") return } @@ -2832,11 +2835,11 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe } } if err != nil { - log.Printf("Could not create MCU %s for session %s to send %+v to %s: %s", clientType, session.PublicId(), data, message.Recipient.SessionId, err) + h.logger.Printf("Could not create MCU %s for session %s to send %+v to %s: %s", clientType, session.PublicId(), data, message.Recipient.SessionId, err) sendMcuClientNotFound(session, client_message) return } else if mc == nil { - log.Printf("No MCU %s found for session %s to send %+v to %s", clientType, session.PublicId(), data, message.Recipient.SessionId) + h.logger.Printf("No MCU %s found for session %s to send %+v to %s", clientType, session.PublicId(), data, message.Recipient.SessionId) sendMcuClientNotFound(session, client_message) return } @@ -2844,7 +2847,7 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe mc.SendMessage(session.Context(), message, data, func(err error, response api.StringMap) { if err != nil { if !errors.Is(err, ErrCandidateFiltered) { - log.Printf("Could not send MCU message %+v for session %s to %s: %s", data, session.PublicId(), message.Recipient.SessionId, err) + h.logger.Printf("Could not send MCU message %+v for session %s to %s: %s", data, session.PublicId(), message.Recipient.SessionId, err) sendMcuProcessingFailed(session, client_message) } return @@ -2871,7 +2874,7 @@ func (h *Hub) sendMcuMessageResponse(session *ClientSession, mcuClient McuClient } answer_data, err := json.Marshal(answer_message) if err != nil { - log.Printf("Could not serialize answer %+v to %s: %s", answer_message, session.PublicId(), err) + h.logger.Printf("Could not serialize answer %+v to %s: %s", answer_message, session.PublicId(), err) return } response_message = &ServerMessage{ @@ -2896,7 +2899,7 @@ func (h *Hub) sendMcuMessageResponse(session *ClientSession, mcuClient McuClient } offer_data, err := json.Marshal(offer_message) if err != nil { - log.Printf("Could not serialize offer %+v to %s: %s", offer_message, session.PublicId(), err) + h.logger.Printf("Could not serialize offer %+v to %s: %s", offer_message, session.PublicId(), err) return } response_message = &ServerMessage{ @@ -2911,7 +2914,7 @@ func (h *Hub) sendMcuMessageResponse(session *ClientSession, mcuClient McuClient }, } default: - log.Printf("Unsupported response %+v received to send to %s", response, session.PublicId()) + h.logger.Printf("Unsupported response %+v received to send to %s", response, session.PublicId()) return } @@ -2952,7 +2955,7 @@ func (h *Hub) processRoomInCallChanged(message *BackendServerRoomRequest) { if err := json.Unmarshal(message.InCall.InCall, &flags); err != nil { var incall bool if err := json.Unmarshal(message.InCall.InCall, &incall); err != nil { - log.Printf("Unsupported InCall flags type: %+v, ignoring", string(message.InCall.InCall)) + h.logger.Printf("Unsupported InCall flags type: %+v, ignoring", string(message.InCall.InCall)) return } @@ -3093,18 +3096,19 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { conn, err := h.upgrader.Upgrade(w, r, header) if err != nil { - log.Printf("Could not upgrade request from %s: %s", addr, err) + h.logger.Printf("Could not upgrade request from %s: %s", addr, err) return } + ctx := NewLoggerContext(r.Context(), h.logger) if conn.Subprotocol() == JanusEventsSubprotocol { - RunJanusEventsHandler(r.Context(), h.mcu, conn, addr, agent) + RunJanusEventsHandler(ctx, h.mcu, conn, addr, agent) return } - client, err := NewClient(r.Context(), conn, addr, agent, h) + client, err := NewClient(ctx, conn, addr, agent, h) if err != nil { - log.Printf("Could not create client for %s: %s", addr, err) + h.logger.Printf("Could not create client for %s: %s", addr, err) return } @@ -3143,7 +3147,7 @@ func (h *Hub) OnLookupCountry(client HandlerClient) string { var err error country, err = h.geoip.LookupCountry(ip) if err != nil { - log.Printf("Could not lookup country for %s: %s", ip, err) + h.logger.Printf("Could not lookup country for %s: %s", ip, err) return unknownCountry } diff --git a/hub_test.go b/hub_test.go index 12c6615..3cda700 100644 --- a/hub_test.go +++ b/hub_test.go @@ -150,6 +150,8 @@ func getTestConfigWithMultipleUrls(server *httptest.Server) (*goconf.ConfigFile, } func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, AsyncEvents, *mux.Router, *httptest.Server) { + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) r := mux.NewRouter() registerBackendHandler(t, r) @@ -162,9 +164,9 @@ func CreateHubForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Serve events := getAsyncEventsForTest(t) config, err := getConfigFunc(server) require.NoError(err) - h, err := NewHub(config, events, nil, nil, nil, r, "no-version") + h, err := NewHub(ctx, config, events, nil, nil, nil, r, "no-version") require.NoError(err) - b, err := NewBackendServer(config, h, "no-version") + b, err := NewBackendServer(ctx, config, h, "no-version") require.NoError(err) require.NoError(b.Start(r)) @@ -199,6 +201,8 @@ func CreateHubWithMultipleUrlsForTest(t *testing.T) (*Hub, AsyncEvents, *mux.Rou } func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, *Hub, *mux.Router, *mux.Router, *httptest.Server, *httptest.Server) { + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) r1 := mux.NewRouter() registerBackendHandler(t, r1) @@ -231,7 +235,7 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http addr1, addr2 = addr2, addr1 } - events1, err := NewAsyncEvents(nats1.ClientURL()) + events1, err := NewAsyncEvents(ctx, nats1.ClientURL()) require.NoError(err) t.Cleanup(func() { events1.Close() @@ -239,11 +243,11 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http config1, err := getConfigFunc(server1) require.NoError(err) client1, _ := NewGrpcClientsForTest(t, addr2) - h1, err := NewHub(config1, events1, grpcServer1, client1, nil, r1, "no-version") + h1, err := NewHub(ctx, config1, events1, grpcServer1, client1, nil, r1, "no-version") require.NoError(err) - b1, err := NewBackendServer(config1, h1, "no-version") + b1, err := NewBackendServer(ctx, config1, h1, "no-version") require.NoError(err) - events2, err := NewAsyncEvents(nats2.ClientURL()) + events2, err := NewAsyncEvents(ctx, nats2.ClientURL()) require.NoError(err) t.Cleanup(func() { events2.Close() @@ -251,9 +255,9 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http config2, err := getConfigFunc(server2) require.NoError(err) client2, _ := NewGrpcClientsForTest(t, addr1) - h2, err := NewHub(config2, events2, grpcServer2, client2, nil, r2, "no-version") + h2, err := NewHub(ctx, config2, events2, grpcServer2, client2, nil, r2, "no-version") require.NoError(err) - b2, err := NewBackendServer(config2, h2, "no-version") + b2, err := NewBackendServer(ctx, config2, h2, "no-version") require.NoError(err) require.NoError(b1.Start(r1)) require.NoError(b2.Start(r2)) @@ -820,7 +824,6 @@ func performHousekeeping(hub *Hub, now time.Time) *sync.WaitGroup { func TestWebsocketFeatures(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) _, _, _, server := CreateHubForTest(t) @@ -852,7 +855,6 @@ func TestWebsocketFeatures(t *testing.T) { func TestInitialWelcome(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -873,7 +875,6 @@ func TestInitialWelcome(t *testing.T) { func TestExpectClientHello(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -900,7 +901,6 @@ func TestExpectClientHello(t *testing.T) { func TestExpectClientHelloUnsupportedVersion(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -925,7 +925,6 @@ func TestExpectClientHelloUnsupportedVersion(t *testing.T) { func TestClientHelloV1(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -939,7 +938,6 @@ func TestClientHelloV1(t *testing.T) { } func TestClientHelloV2(t *testing.T) { - CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { require := require.New(t) @@ -976,7 +974,6 @@ func TestClientHelloV2(t *testing.T) { } func TestClientHelloV2_IssuedInFuture(t *testing.T) { - CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { require := require.New(t) @@ -1002,7 +999,6 @@ func TestClientHelloV2_IssuedInFuture(t *testing.T) { } func TestClientHelloV2_IssuedFarInFuture(t *testing.T) { - CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { require := require.New(t) @@ -1024,7 +1020,6 @@ func TestClientHelloV2_IssuedFarInFuture(t *testing.T) { } func TestClientHelloV2_Expired(t *testing.T) { - CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { require := require.New(t) @@ -1046,7 +1041,6 @@ func TestClientHelloV2_Expired(t *testing.T) { } func TestClientHelloV2_IssuedAtMissing(t *testing.T) { - CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { require := require.New(t) @@ -1068,7 +1062,6 @@ func TestClientHelloV2_IssuedAtMissing(t *testing.T) { } func TestClientHelloV2_ExpiresAtMissing(t *testing.T) { - CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { require := require.New(t) @@ -1090,7 +1083,6 @@ func TestClientHelloV2_ExpiresAtMissing(t *testing.T) { } func TestClientHelloV2_CachedCapabilities(t *testing.T) { - CatchLogForTest(t) for _, algo := range testHelloV2Algorithms { t.Run(algo, func(t *testing.T) { require := require.New(t) @@ -1129,7 +1121,6 @@ func TestClientHelloV2_CachedCapabilities(t *testing.T) { func TestClientHelloWithSpaces(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1145,7 +1136,6 @@ func TestClientHelloWithSpaces(t *testing.T) { func TestClientHelloAllowAll(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) hub, _, _, server := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { config, err := getTestConfig(server) @@ -1167,7 +1157,6 @@ func TestClientHelloAllowAll(t *testing.T) { } func TestClientHelloSessionLimit(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -1297,7 +1286,6 @@ func TestClientHelloSessionLimit(t *testing.T) { func TestSessionIdsUnordered(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1364,7 +1352,6 @@ func TestSessionIdsUnordered(t *testing.T) { func TestClientHelloResume(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1393,7 +1380,6 @@ func TestClientHelloResume(t *testing.T) { func TestClientHelloResumeThrottle(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1451,7 +1437,6 @@ func TestClientHelloResumeThrottle(t *testing.T) { func TestClientHelloResumeExpired(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1480,7 +1465,6 @@ func TestClientHelloResumeExpired(t *testing.T) { func TestClientHelloResumeTakeover(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1515,7 +1499,6 @@ func TestClientHelloResumeTakeover(t *testing.T) { func TestClientHelloResumeOtherHub(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1565,7 +1548,6 @@ func TestClientHelloResumeOtherHub(t *testing.T) { func TestClientHelloResumePublicId(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) // Test that a client can't resume a "public" session of another user. @@ -1608,7 +1590,6 @@ func TestClientHelloResumePublicId(t *testing.T) { func TestClientHelloByeResume(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1639,7 +1620,6 @@ func TestClientHelloByeResume(t *testing.T) { func TestClientHelloResumeAndJoin(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1711,7 +1691,6 @@ func runGrpcProxyTest(t *testing.T, f func(hub1, hub2 *Hub, server1, server2 *ht } func TestClientHelloResumeProxy(t *testing.T) { - CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) { require := require.New(t) @@ -1762,7 +1741,6 @@ func TestClientHelloResumeProxy(t *testing.T) { } func TestClientHelloResumeProxy_Takeover(t *testing.T) { - CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) { require := require.New(t) @@ -1817,7 +1795,6 @@ func TestClientHelloResumeProxy_Takeover(t *testing.T) { } func TestClientHelloResumeProxy_Disconnect(t *testing.T) { - CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) { require := require.New(t) @@ -1852,7 +1829,6 @@ func TestClientHelloResumeProxy_Disconnect(t *testing.T) { func TestClientHelloClient(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1867,7 +1843,6 @@ func TestClientHelloClient(t *testing.T) { func TestClientHelloClient_V3Api(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1894,7 +1869,6 @@ func TestClientHelloClient_V3Api(t *testing.T) { func TestClientHelloInternal(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -1915,7 +1889,6 @@ func TestClientHelloInternal(t *testing.T) { } func TestClientMessageToSessionId(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -1935,13 +1908,11 @@ func TestClientMessageToSessionId(t *testing.T) { hub1, hub2, server1, server2 = CreateClusteredHubsForTest(t) } - mcu1, err := NewTestMCU() - require.NoError(err) + mcu1 := NewTestMCU(t) hub1.SetMcu(mcu1) if hub1 != hub2 { - mcu2, err := NewTestMCU() - require.NoError(err) + mcu2 := NewTestMCU(t) hub2.SetMcu(mcu2) } @@ -1982,7 +1953,6 @@ func TestClientMessageToSessionId(t *testing.T) { } func TestClientControlToSessionId(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -2036,7 +2006,6 @@ func TestClientControlToSessionId(t *testing.T) { func TestClientControlMissingPermissions(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2092,7 +2061,6 @@ func TestClientControlMissingPermissions(t *testing.T) { func TestClientMessageToUserId(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2131,7 +2099,6 @@ func TestClientMessageToUserId(t *testing.T) { func TestClientControlToUserId(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2170,7 +2137,6 @@ func TestClientControlToUserId(t *testing.T) { func TestClientMessageToUserIdMultipleSessions(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2209,7 +2175,6 @@ func TestClientMessageToUserIdMultipleSessions(t *testing.T) { } func TestClientMessageToRoom(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -2272,7 +2237,6 @@ func TestClientMessageToRoom(t *testing.T) { } func TestClientControlToRoom(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -2335,7 +2299,6 @@ func TestClientControlToRoom(t *testing.T) { } func TestClientMessageToCall(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -2440,7 +2403,6 @@ func TestClientMessageToCall(t *testing.T) { } func TestClientControlToCall(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -2546,7 +2508,6 @@ func TestClientControlToCall(t *testing.T) { func TestJoinRoom(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2572,7 +2533,6 @@ func TestJoinRoom(t *testing.T) { func TestJoinRoomBackendBandwidth(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { @@ -2607,12 +2567,10 @@ func TestJoinRoomBackendBandwidth(t *testing.T) { func TestJoinRoomMcuBandwidth(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) - mcu, err := NewTestMCU() - require.NoError(err) + mcu := NewTestMCU(t) hub.SetMcu(mcu) mcu.SetBandwidthLimits(1000, 2000) @@ -2638,7 +2596,6 @@ func TestJoinRoomMcuBandwidth(t *testing.T) { func TestJoinRoomPreferMcuBandwidth(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { @@ -2652,8 +2609,7 @@ func TestJoinRoomPreferMcuBandwidth(t *testing.T) { return config, nil }) - mcu, err := NewTestMCU() - require.NoError(err) + mcu := NewTestMCU(t) hub.SetMcu(mcu) // The MCU bandwidth limits overwrite any backend limits. @@ -2680,7 +2636,6 @@ func TestJoinRoomPreferMcuBandwidth(t *testing.T) { func TestJoinInvalidRoom(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2712,7 +2667,6 @@ func TestJoinInvalidRoom(t *testing.T) { func TestJoinRoomTwice(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2760,7 +2714,6 @@ func TestJoinRoomTwice(t *testing.T) { func TestExpectAnonymousJoinRoom(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2789,7 +2742,6 @@ func TestExpectAnonymousJoinRoom(t *testing.T) { func TestExpectAnonymousJoinRoomAfterLeave(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2841,7 +2793,6 @@ func TestExpectAnonymousJoinRoomAfterLeave(t *testing.T) { func TestJoinRoomChange(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2873,7 +2824,6 @@ func TestJoinRoomChange(t *testing.T) { func TestJoinMultiple(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2914,7 +2864,6 @@ func TestJoinMultiple(t *testing.T) { func TestJoinDisplaynamesPermission(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2962,7 +2911,6 @@ func TestJoinDisplaynamesPermission(t *testing.T) { func TestInitialRoomPermissions(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -2988,7 +2936,6 @@ func TestInitialRoomPermissions(t *testing.T) { func TestJoinRoomSwitchClient(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -3267,7 +3214,6 @@ func TestGetRealUserIP(t *testing.T) { func TestClientMessageToSessionIdWhileDisconnected(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -3312,7 +3258,6 @@ func TestClientMessageToSessionIdWhileDisconnected(t *testing.T) { func TestCombineChatRefreshWhileDisconnected(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -3387,7 +3332,6 @@ func TestCombineChatRefreshWhileDisconnected(t *testing.T) { func TestRoomParticipantsListUpdateWhileDisconnected(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -3467,7 +3411,6 @@ func TestRoomParticipantsListUpdateWhileDisconnected(t *testing.T) { } func TestClientTakeoverRoomSession(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -3555,15 +3498,13 @@ func RunTestClientTakeoverRoomSession(t *testing.T) { func TestClientSendOfferPermissions(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - mcu, err := NewTestMCU() - require.NoError(err) + mcu := NewTestMCU(t) require.NoError(mcu.Start(ctx)) defer mcu.Stop() @@ -3644,15 +3585,13 @@ func TestClientSendOfferPermissions(t *testing.T) { func TestClientSendOfferPermissionsAudioOnly(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - mcu, err := NewTestMCU() - require.NoError(err) + mcu := NewTestMCU(t) require.NoError(mcu.Start(ctx)) defer mcu.Stop() @@ -3707,7 +3646,6 @@ func TestClientSendOfferPermissionsAudioOnly(t *testing.T) { func TestClientSendOfferPermissionsAudioVideo(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -3715,8 +3653,7 @@ func TestClientSendOfferPermissionsAudioVideo(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - mcu, err := NewTestMCU() - require.NoError(err) + mcu := NewTestMCU(t) require.NoError(mcu.Start(ctx)) defer mcu.Stop() @@ -3805,7 +3742,6 @@ loop: func TestClientSendOfferPermissionsAudioVideoMedia(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -3813,8 +3749,7 @@ func TestClientSendOfferPermissionsAudioVideoMedia(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - mcu, err := NewTestMCU() - require.NoError(err) + mcu := NewTestMCU(t) require.NoError(mcu.Start(ctx)) defer mcu.Stop() @@ -3905,7 +3840,6 @@ loop: } func TestClientRequestOfferNotInRoom(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -3926,8 +3860,7 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - mcu, err := NewTestMCU() - require.NoError(err) + mcu := NewTestMCU(t) require.NoError(mcu.Start(ctx)) defer mcu.Stop() @@ -4066,7 +3999,6 @@ func TestClientRequestOfferNotInRoom(t *testing.T) { func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) // Clients can't send messages to sessions connected from other backends. hub, _, _, server := CreateHubWithMultipleBackendsForTest(t) @@ -4119,7 +4051,6 @@ func TestNoSendBetweenSessionsOnDifferentBackends(t *testing.T) { func TestSendBetweenDifferentUrls(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubWithMultipleUrlsForTest(t) @@ -4170,7 +4101,6 @@ func TestSendBetweenDifferentUrls(t *testing.T) { func TestNoSameRoomOnDifferentBackends(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubWithMultipleBackendsForTest(t) @@ -4244,7 +4174,6 @@ func TestNoSameRoomOnDifferentBackends(t *testing.T) { func TestSameRoomOnDifferentUrls(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubWithMultipleUrlsForTest(t) @@ -4309,7 +4238,6 @@ func TestSameRoomOnDifferentUrls(t *testing.T) { } func TestClientSendOffer(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -4330,8 +4258,7 @@ func TestClientSendOffer(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - mcu, err := NewTestMCU() - require.NoError(err) + mcu := NewTestMCU(t) require.NoError(mcu.Start(ctx)) defer mcu.Stop() @@ -4390,15 +4317,13 @@ func TestClientSendOffer(t *testing.T) { func TestClientUnshareScreen(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) hub, _, _, server := CreateHubForTest(t) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - mcu, err := NewTestMCU() - require.NoError(err) + mcu := NewTestMCU(t) require.NoError(mcu.Start(ctx)) defer mcu.Stop() @@ -4455,7 +4380,6 @@ func TestClientUnshareScreen(t *testing.T) { } func TestVirtualClientSessions(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -4696,7 +4620,6 @@ func TestVirtualClientSessions(t *testing.T) { } func TestDuplicateVirtualSessions(t *testing.T) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -4970,7 +4893,6 @@ func TestDuplicateVirtualSessions(t *testing.T) { } func DoTestSwitchToOne(t *testing.T, details api.StringMap) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -5067,7 +4989,6 @@ func TestSwitchToOneList(t *testing.T) { } func DoTestSwitchToMultiple(t *testing.T, details1 api.StringMap, details2 api.StringMap) { - CatchLogForTest(t) for _, subtest := range clusteredTests { t.Run(subtest, func(t *testing.T) { t.Parallel() @@ -5175,7 +5096,6 @@ func TestSwitchToMultipleMixed(t *testing.T) { func TestGeoipOverrides(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) country1 := "DE" country2 := "IT" @@ -5201,7 +5121,8 @@ func TestGeoipOverrides(t *testing.T) { func TestDialoutStatus(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) _, _, _, hub, _, server := CreateBackendServerForTest(t) @@ -5210,7 +5131,7 @@ func TestDialoutStatus(t *testing.T) { defer internalClient.CloseWithBye() require.NoError(internalClient.SendHelloInternalWithFeatures([]string{"start-dialout"})) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() MustSucceed1(t, internalClient.RunUntilHello, ctx) @@ -5357,7 +5278,6 @@ func TestDialoutStatus(t *testing.T) { func TestGracefulShutdownInitial(t *testing.T) { t.Parallel() - CatchLogForTest(t) hub, _, _, _ := CreateHubForTest(t) hub.ScheduleShutdown() @@ -5366,7 +5286,6 @@ func TestGracefulShutdownInitial(t *testing.T) { func TestGracefulShutdownOnBye(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -5393,7 +5312,6 @@ func TestGracefulShutdownOnBye(t *testing.T) { func TestGracefulShutdownOnExpiration(t *testing.T) { t.Parallel() - CatchLogForTest(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) diff --git a/mcu_common.go b/mcu_common.go index 75d5fb4..ba5a736 100644 --- a/mcu_common.go +++ b/mcu_common.go @@ -24,7 +24,6 @@ package signaling import ( "context" "fmt" - "log" "sync/atomic" "time" @@ -82,6 +81,8 @@ type McuSettings interface { } type mcuCommonSettings struct { + logger Logger + maxStreamBitrate api.AtomicBandwidth maxScreenBitrate api.AtomicBandwidth @@ -110,7 +111,7 @@ func (s *mcuCommonSettings) load(config *goconf.ConfigFile) error { maxStreamBitrateValue = int(defaultMaxStreamBitrate.Bits()) } maxStreamBitrate := api.BandwidthFromBits(uint64(maxStreamBitrateValue)) - log.Printf("Maximum bandwidth %s per publishing stream", maxStreamBitrate) + s.logger.Printf("Maximum bandwidth %s per publishing stream", maxStreamBitrate) s.maxStreamBitrate.Store(maxStreamBitrate) maxScreenBitrateValue, _ := config.GetInt("mcu", "maxscreenbitrate") @@ -118,7 +119,7 @@ func (s *mcuCommonSettings) load(config *goconf.ConfigFile) error { maxScreenBitrateValue = int(defaultMaxScreenBitrate.Bits()) } maxScreenBitrate := api.BandwidthFromBits(uint64(maxScreenBitrateValue)) - log.Printf("Maximum bandwidth %s per screensharing stream", maxScreenBitrate) + s.logger.Printf("Maximum bandwidth %s per screensharing stream", maxScreenBitrate) s.maxScreenBitrate.Store(maxScreenBitrate) return nil } diff --git a/mcu_janus.go b/mcu_janus.go index d4558d6..7e1e29b 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -26,7 +26,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "strconv" "sync" "sync/atomic" @@ -108,7 +107,7 @@ func convertIntValue(value any) (uint64, error) { } } -func getPluginIntValue(data janus.PluginData, pluginName string, key string) uint64 { +func getPluginIntValue(logger Logger, data janus.PluginData, pluginName string, key string) uint64 { val := getPluginValue(data, pluginName, key) if val == nil { return 0 @@ -116,7 +115,7 @@ func getPluginIntValue(data janus.PluginData, pluginName string, key string) uin result, err := convertIntValue(val) if err != nil { - log.Printf("Invalid value %+v for %s: %s", val, key, err) + logger.Printf("Invalid value %+v for %s: %s", val, key, err) result = 0 } return result @@ -154,8 +153,12 @@ type mcuJanusSettings struct { blockedCandidates atomic.Pointer[AllowedIps] } -func newMcuJanusSettings(config *goconf.ConfigFile) (*mcuJanusSettings, error) { - settings := &mcuJanusSettings{} +func newMcuJanusSettings(ctx context.Context, config *goconf.ConfigFile) (*mcuJanusSettings, error) { + settings := &mcuJanusSettings{ + mcuCommonSettings: mcuCommonSettings{ + logger: LoggerFromContext(ctx), + }, + } if err := settings.load(config); err != nil { return nil, err } @@ -173,7 +176,7 @@ func (s *mcuJanusSettings) load(config *goconf.ConfigFile) error { mcuTimeoutSeconds = defaultMcuTimeoutSeconds } mcuTimeout := time.Duration(mcuTimeoutSeconds) * time.Second - log.Printf("Using a timeout of %s for MCU requests", mcuTimeout) + s.logger.Printf("Using a timeout of %s for MCU requests", mcuTimeout) s.setTimeout(mcuTimeout) if value, _ := config.GetString("mcu", "allowedcandidates"); value != "" { @@ -182,10 +185,10 @@ func (s *mcuJanusSettings) load(config *goconf.ConfigFile) error { return fmt.Errorf("invalid allowedcandidates: %w", err) } - log.Printf("Candidates allowlist: %s", allowed) + s.logger.Printf("Candidates allowlist: %s", allowed) s.allowedCandidates.Store(allowed) } else { - log.Printf("No candidates allowlist") + s.logger.Printf("No candidates allowlist") s.allowedCandidates.Store(nil) } if value, _ := config.GetString("mcu", "blockedcandidates"); value != "" { @@ -194,10 +197,10 @@ func (s *mcuJanusSettings) load(config *goconf.ConfigFile) error { return fmt.Errorf("invalid blockedcandidates: %w", err) } - log.Printf("Candidates blocklist: %s", blocked) + s.logger.Printf("Candidates blocklist: %s", blocked) s.blockedCandidates.Store(blocked) } else { - log.Printf("No candidates blocklist") + s.logger.Printf("No candidates blocklist") s.blockedCandidates.Store(nil) } @@ -206,11 +209,13 @@ func (s *mcuJanusSettings) load(config *goconf.ConfigFile) error { func (s *mcuJanusSettings) Reload(config *goconf.ConfigFile) { if err := s.load(config); err != nil { - log.Printf("Error reloading MCU settings: %s", err) + s.logger.Printf("Error reloading MCU settings: %s", err) } } type mcuJanus struct { + logger Logger + url string mu sync.Mutex @@ -251,12 +256,13 @@ func emptyOnConnected() {} func emptyOnDisconnected() {} func NewMcuJanus(ctx context.Context, url string, config *goconf.ConfigFile) (Mcu, error) { - settings, err := newMcuJanusSettings(config) + settings, err := newMcuJanusSettings(ctx, config) if err != nil { return nil, err } mcu := &mcuJanus{ + logger: LoggerFromContext(ctx), url: url, settings: settings, closeChan: make(chan struct{}, 1), @@ -290,18 +296,18 @@ func (m *mcuJanus) disconnect() { m.handle = nil m.closeChan <- struct{}{} if _, err := handle.Detach(context.TODO()); err != nil { - log.Printf("Error detaching handle %d: %s", handle.Id, err) + m.logger.Printf("Error detaching handle %d: %s", handle.Id, err) } } if m.session != nil { if _, err := m.session.Destroy(context.TODO()); err != nil { - log.Printf("Error destroying session %d: %s", m.session.Id, err) + m.logger.Printf("Error destroying session %d: %s", m.session.Id, err) } m.session = nil } if m.gw != nil { if err := m.gw.Close(); err != nil { - log.Println("Error while closing connection to MCU", err) + m.logger.Println("Error while closing connection to MCU", err) } m.gw = nil } @@ -371,7 +377,7 @@ func (m *mcuJanus) doReconnect(ctx context.Context) { return } - log.Println("Reconnection to Janus gateway successful") + m.logger.Println("Reconnection to Janus gateway successful") m.mu.Lock() clear(m.publishers) m.publisherCreated.Reset() @@ -407,9 +413,9 @@ func (m *mcuJanus) scheduleReconnect(err error) { defer m.mu.Unlock() m.reconnectTimer.Reset(m.reconnectInterval) if err == nil { - log.Printf("Connection to Janus gateway was interrupted, reconnecting in %s", m.reconnectInterval) + m.logger.Printf("Connection to Janus gateway was interrupted, reconnecting in %s", m.reconnectInterval) } else { - log.Printf("Reconnect to Janus gateway failed (%s), reconnecting in %s", err, m.reconnectInterval) + m.logger.Printf("Reconnect to Janus gateway failed (%s), reconnecting in %s", err, m.reconnectInterval) } m.reconnectInterval = min(m.reconnectInterval*2, maxReconnectInterval) @@ -439,49 +445,49 @@ func (m *mcuJanus) Start(ctx context.Context) error { return err } - log.Printf("Connected to %s %s by %s", info.Name, info.VersionString, info.Author) + m.logger.Printf("Connected to %s %s by %s", info.Name, info.VersionString, info.Author) m.version = info.Version if plugin, found := info.Plugins[pluginVideoRoom]; found { - log.Printf("Found %s %s by %s", plugin.Name, plugin.VersionString, plugin.Author) + m.logger.Printf("Found %s %s by %s", plugin.Name, plugin.VersionString, plugin.Author) } else { return fmt.Errorf("plugin %s is not supported", pluginVideoRoom) } if plugin, found := info.Events[eventWebsocket]; found { if !info.EventHandlers { - log.Printf("Found %s %s by %s but event handlers are disabled, realtime usage will not be available", plugin.Name, plugin.VersionString, plugin.Author) + m.logger.Printf("Found %s %s by %s but event handlers are disabled, realtime usage will not be available", plugin.Name, plugin.VersionString, plugin.Author) } else { - log.Printf("Found %s %s by %s", plugin.Name, plugin.VersionString, plugin.Author) + m.logger.Printf("Found %s %s by %s", plugin.Name, plugin.VersionString, plugin.Author) } } else { - log.Printf("Plugin %s not found, realtime usage will not be available", eventWebsocket) + m.logger.Printf("Plugin %s not found, realtime usage will not be available", eventWebsocket) } - log.Printf("Used dependencies: %+v", info.Dependencies) + m.logger.Printf("Used dependencies: %+v", info.Dependencies) if !info.DataChannels { return fmt.Errorf("data channels are not supported") } - log.Println("Data channels are supported") + m.logger.Println("Data channels are supported") if !info.FullTrickle { - log.Println("WARNING: Full-Trickle is NOT enabled in Janus!") + m.logger.Println("WARNING: Full-Trickle is NOT enabled in Janus!") } else { - log.Println("Full-Trickle is enabled") + m.logger.Println("Full-Trickle is enabled") } if m.session, err = m.gw.Create(ctx); err != nil { m.disconnect() return err } - log.Println("Created Janus session", m.session.Id) + m.logger.Println("Created Janus session", m.session.Id) m.connectedSince = time.Now() if m.handle, err = m.session.Attach(ctx, pluginVideoRoom); err != nil { m.disconnect() return err } - log.Println("Created Janus handle", m.handle.Id) + m.logger.Println("Created Janus handle", m.handle.Id) m.info.Store(info) @@ -627,7 +633,7 @@ func (m *mcuJanus) GetStats() any { func (m *mcuJanus) sendKeepalive(ctx context.Context) { if _, err := m.session.KeepAlive(ctx); err != nil { - log.Println("Could not send keepalive request", err) + m.logger.Println("Could not send keepalive request", err) if e, ok := err.(*janus.ErrorMsg); ok { switch e.Err.Code { case JANUS_ERROR_SESSION_NOT_FOUND: @@ -693,20 +699,20 @@ func (m *mcuJanus) createPublisherRoom(ctx context.Context, handle *JanusHandle, create_response, err := handle.Request(ctx, create_msg) if err != nil { if _, err2 := handle.Detach(ctx); err2 != nil { - log.Printf("Error detaching handle %d: %s", handle.Id, err2) + m.logger.Printf("Error detaching handle %d: %s", handle.Id, err2) } return 0, 0, err } - roomId := getPluginIntValue(create_response.PluginData, pluginVideoRoom, "room") + roomId := getPluginIntValue(m.logger, create_response.PluginData, pluginVideoRoom, "room") if roomId == 0 { if _, err := handle.Detach(ctx); err != nil { - log.Printf("Error detaching handle %d: %s", handle.Id, err) + m.logger.Printf("Error detaching handle %d: %s", handle.Id, err) } return 0, 0, fmt.Errorf("no room id received: %+v", create_response) } - log.Println("Created room", roomId, create_response.PluginData) + m.logger.Println("Created room", roomId, create_response.PluginData) return roomId, bitrate, nil } @@ -720,12 +726,12 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id PublicSess return nil, 0, 0, 0, err } - log.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id) + m.logger.Printf("Attached %s as publisher %d to plugin %s in session %d", streamType, handle.Id, pluginVideoRoom, session.Id) roomId, bitrate, err := m.createPublisherRoom(ctx, handle, id, streamType, settings) if err != nil { if _, err2 := handle.Detach(ctx); err2 != nil { - log.Printf("Error detaching handle %d: %s", handle.Id, err2) + m.logger.Printf("Error detaching handle %d: %s", handle.Id, err2) } return nil, 0, 0, 0, err } @@ -740,7 +746,7 @@ func (m *mcuJanus) getOrCreatePublisherHandle(ctx context.Context, id PublicSess response, err := handle.Message(ctx, msg, nil) if err != nil { if _, err2 := handle.Detach(ctx); err2 != nil { - log.Printf("Error detaching handle %d: %s", handle.Id, err2) + m.logger.Printf("Error detaching handle %d: %s", handle.Id, err2) } return nil, 0, 0, 0, err } @@ -760,6 +766,7 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id Pu client := &mcuJanusPublisher{ mcuJanusClient: mcuJanusClient{ + logger: m.logger, mcu: m, listener: listener, @@ -787,7 +794,7 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id Pu client.mcuJanusClient.handleMedia = client.handleMedia m.registerClient(client) - log.Printf("Publisher %s is using handle %d", client.id, handle.Id) + m.logger.Printf("Publisher %s is using handle %d", client.id, handle.Id) go client.run(handle, client.closeChan) m.mu.Lock() m.publishers[getStreamId(id, streamType)] = client @@ -842,7 +849,7 @@ func (m *mcuJanus) getOrCreateSubscriberHandle(ctx context.Context, publisher Pu return nil, nil, err } - log.Printf("Attached subscriber to room %d of publisher %s in plugin %s in session %d as %d", pub.roomId, publisher, pluginVideoRoom, session.Id, handle.Id) + m.logger.Printf("Attached subscriber to room %d of publisher %s in plugin %s in session %d as %d", pub.roomId, publisher, pluginVideoRoom, session.Id, handle.Id) return handle, pub, nil } @@ -858,6 +865,7 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ client := &mcuJanusSubscriber{ mcuJanusClient: mcuJanusClient{ + logger: m.logger, mcu: m, listener: listener, @@ -917,7 +925,7 @@ func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller Re roomId, maxBitrate, err := m.createPublisherRoom(ctx, handle, controller.PublisherId(), streamType, settings) if err != nil { if _, err2 := handle.Detach(ctx); err2 != nil { - log.Printf("Error detaching handle %d: %s", handle.Id, err2) + m.logger.Printf("Error detaching handle %d: %s", handle.Id, err2) } return nil, err } @@ -930,19 +938,20 @@ func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller Re }) if err != nil { if _, err2 := handle.Detach(ctx); err2 != nil { - log.Printf("Error detaching handle %d: %s", handle.Id, err2) + m.logger.Printf("Error detaching handle %d: %s", handle.Id, err2) } return nil, err } - id := getPluginIntValue(response.PluginData, pluginVideoRoom, "id") - port := getPluginIntValue(response.PluginData, pluginVideoRoom, "port") - rtcp_port := getPluginIntValue(response.PluginData, pluginVideoRoom, "rtcp_port") + id := getPluginIntValue(m.logger, response.PluginData, pluginVideoRoom, "id") + port := getPluginIntValue(m.logger, response.PluginData, pluginVideoRoom, "port") + rtcp_port := getPluginIntValue(m.logger, response.PluginData, pluginVideoRoom, "rtcp_port") pub = &mcuJanusRemotePublisher{ mcuJanusPublisher: mcuJanusPublisher{ mcuJanusClient: mcuJanusClient{ - mcu: m, + logger: m.logger, + mcu: m, id: id, session: response.Session, @@ -1018,11 +1027,12 @@ func (m *mcuJanus) NewRemoteSubscriber(ctx context.Context, listener McuListener return nil, err } - log.Printf("Attached subscriber to room %d of publisher %s in plugin %s in session %d as %d", pub.roomId, pub.id, pluginVideoRoom, session.Id, handle.Id) + m.logger.Printf("Attached subscriber to room %d of publisher %s in plugin %s in session %d as %d", pub.roomId, pub.id, pluginVideoRoom, session.Id, handle.Id) client := &mcuJanusRemoteSubscriber{ mcuJanusSubscriber: mcuJanusSubscriber{ mcuJanusClient: mcuJanusClient{ + logger: m.logger, mcu: m, listener: listener, diff --git a/mcu_janus_client.go b/mcu_janus_client.go index 8945199..754ec2a 100644 --- a/mcu_janus_client.go +++ b/mcu_janus_client.go @@ -23,7 +23,6 @@ package signaling import ( "context" - "log" "reflect" "strconv" "sync" @@ -35,6 +34,7 @@ import ( ) type mcuJanusClient struct { + logger Logger mcu *mcuJanus listener McuListener mu sync.Mutex @@ -127,7 +127,7 @@ func (c *mcuJanusClient) closeClient(ctx context.Context) bool { close(c.closeChan) if _, err := handle.Detach(ctx); err != nil { if e, ok := err.(*janus.ErrorMsg); !ok || e.Err.Code != JANUS_ERROR_HANDLE_NOT_FOUND { - log.Println("Could not detach client", handle.Id, err) + c.logger.Println("Could not detach client", handle.Id, err) } } return true @@ -157,7 +157,7 @@ loop: case *TrickleMsg: c.handleTrickle(t) default: - log.Println("Received unsupported event type", msg, reflect.TypeOf(msg)) + c.logger.Println("Received unsupported event type", msg, reflect.TypeOf(msg)) } case f := <-c.deferred: f() @@ -205,7 +205,8 @@ func (c *mcuJanusClient) sendAnswer(ctx context.Context, answer api.StringMap, c callback(err, nil) return } - log.Println("Started listener", start_response) + + c.logger.Println("Started listener", start_response) callback(nil, nil) } diff --git a/mcu_janus_events_handler.go b/mcu_janus_events_handler.go index 7d0ddc1..f49ba7a 100644 --- a/mcu_janus_events_handler.go +++ b/mcu_janus_events_handler.go @@ -26,7 +26,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "math" "net" "strconv" @@ -621,8 +620,9 @@ func (h *handleStats) LostRemote(media string, lost uint64) { type JanusEventsHandler struct { mu sync.Mutex - ctx context.Context - mcu McuEventHandler + logger Logger + ctx context.Context + mcu McuEventHandler // +checklocks:mu conn *websocket.Conn addr string @@ -654,7 +654,8 @@ func RunJanusEventsHandler(ctx context.Context, mcu Mcu, conn *websocket.Conn, a client, err := NewJanusEventsHandler(ctx, m, conn, addr, agent) if err != nil { - log.Printf("Could not create Janus events handler for %s: %s", addr, err) + logger := LoggerFromContext(ctx) + logger.Printf("Could not create Janus events handler for %s: %s", addr, err) conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "error creating handler"), deadline) // nolint return } @@ -664,11 +665,12 @@ func RunJanusEventsHandler(ctx context.Context, mcu Mcu, conn *websocket.Conn, a func NewJanusEventsHandler(ctx context.Context, mcu McuEventHandler, conn *websocket.Conn, addr string, agent string) (*JanusEventsHandler, error) { handler := &JanusEventsHandler{ - ctx: ctx, - mcu: mcu, - conn: conn, - addr: addr, - agent: agent, + logger: LoggerFromContext(ctx), + ctx: ctx, + mcu: mcu, + conn: conn, + addr: addr, + agent: agent, events: make(chan JanusEvent, 1), } @@ -677,7 +679,7 @@ func NewJanusEventsHandler(ctx context.Context, mcu McuEventHandler, conn *webso } func (h *JanusEventsHandler) Run() { - log.Printf("Processing Janus events from %s", h.addr) + h.logger.Printf("Processing Janus events from %s", h.addr) go h.writePump() go h.processEvents() @@ -692,7 +694,7 @@ func (h *JanusEventsHandler) close() { if conn != nil { if err := conn.Close(); err != nil { - log.Printf("Error closing %s", err) + h.logger.Printf("Error closing %s", err) } } } @@ -702,7 +704,7 @@ func (h *JanusEventsHandler) readPump() { conn := h.conn h.mu.Unlock() if conn == nil { - log.Printf("Connection from %s closed while starting readPump", h.addr) + h.logger.Printf("Connection from %s closed while starting readPump", h.addr) return } @@ -724,24 +726,24 @@ func (h *JanusEventsHandler) readPump() { websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - log.Printf("Error reading from %s: %v", h.addr, err) + h.logger.Printf("Error reading from %s: %v", h.addr, err) } break } if messageType != websocket.TextMessage { - log.Printf("Unsupported message type %v from %s", messageType, h.addr) + h.logger.Printf("Unsupported message type %v from %s", messageType, h.addr) continue } decodeBuffer, err := bufferPool.ReadAll(reader) if err != nil { - log.Printf("Error reading message from %s: %v", h.addr, err) + h.logger.Printf("Error reading message from %s: %v", h.addr, err) break } if decodeBuffer.Len() == 0 { - log.Printf("Received empty message from %s", h.addr) + h.logger.Printf("Received empty message from %s", h.addr) bufferPool.Put(decodeBuffer) break } @@ -750,7 +752,7 @@ func (h *JanusEventsHandler) readPump() { if data := decodeBuffer.Bytes(); data[0] != '[' { var event JanusEvent if err := json.Unmarshal(data, &event); err != nil { - log.Printf("Error decoding message %s from %s: %v", decodeBuffer.String(), h.addr, err) + h.logger.Printf("Error decoding message %s from %s: %v", decodeBuffer.String(), h.addr, err) bufferPool.Put(decodeBuffer) break } @@ -758,7 +760,7 @@ func (h *JanusEventsHandler) readPump() { events = append(events, event) } else { if err := json.Unmarshal(data, &events); err != nil { - log.Printf("Error decoding message %s from %s: %v", decodeBuffer.String(), h.addr, err) + h.logger.Printf("Error decoding message %s from %s: %v", decodeBuffer.String(), h.addr, err) bufferPool.Put(decodeBuffer) break } @@ -782,7 +784,7 @@ func (h *JanusEventsHandler) sendPing() bool { msg := strconv.FormatInt(now, 10) h.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := h.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil { - log.Printf("Could not send ping to %s: %v", h.addr, err) + h.logger.Printf("Could not send ping to %s: %v", h.addr, err) return false } @@ -848,7 +850,7 @@ func (h *JanusEventsHandler) getHandleStats(event JanusEvent) *handleStats { func (h *JanusEventsHandler) processEvent(event JanusEvent) { evt, err := event.Decode() if err != nil { - log.Printf("Error decoding event %s (%s)", event, err) + h.logger.Printf("Error decoding event %s (%s)", event, err) return } diff --git a/mcu_janus_events_handler_test.go b/mcu_janus_events_handler_test.go index b04d604..55aed9c 100644 --- a/mcu_janus_events_handler_test.go +++ b/mcu_janus_events_handler_test.go @@ -63,7 +63,9 @@ func (h *TestJanusEventsServerHandler) ServeHTTP(w http.ResponseWriter, r *http. if host, _, err := net.SplitHostPort(addr); err == nil { addr = host } - RunJanusEventsHandler(r.Context(), h.mcu, conn, addr, r.Header.Get("User-Agent")) + logger := NewLoggerForTest(h.t) + ctx := NewLoggerContext(r.Context(), logger) + RunJanusEventsHandler(ctx, h.mcu, conn, addr, r.Header.Get("User-Agent")) return } diff --git a/mcu_janus_publisher.go b/mcu_janus_publisher.go index c0135f2..225365b 100644 --- a/mcu_janus_publisher.go +++ b/mcu_janus_publisher.go @@ -25,7 +25,6 @@ import ( "context" "errors" "fmt" - "log" "strconv" "strings" "sync/atomic" @@ -67,38 +66,38 @@ func (p *mcuJanusPublisher) handleEvent(event *janus.EventMsg) { ctx := context.TODO() switch videoroom { case "destroyed": - log.Printf("Publisher %d: associated room has been destroyed, closing", p.handleId.Load()) + p.logger.Printf("Publisher %d: associated room has been destroyed, closing", p.handleId.Load()) go p.Close(ctx) case "slow_link": // Ignore, processed through "handleSlowLink" in the general events. default: - log.Printf("Unsupported videoroom publisher event in %d: %+v", p.handleId.Load(), event) + p.logger.Printf("Unsupported videoroom publisher event in %d: %+v", p.handleId.Load(), event) } } else { - log.Printf("Unsupported publisher event in %d: %+v", p.handleId.Load(), event) + p.logger.Printf("Unsupported publisher event in %d: %+v", p.handleId.Load(), event) } } func (p *mcuJanusPublisher) handleHangup(event *janus.HangupMsg) { - log.Printf("Publisher %d received hangup (%s), closing", p.handleId.Load(), event.Reason) + p.logger.Printf("Publisher %d received hangup (%s), closing", p.handleId.Load(), event.Reason) go p.Close(context.Background()) } func (p *mcuJanusPublisher) handleDetached(event *janus.DetachedMsg) { - log.Printf("Publisher %d received detached, closing", p.handleId.Load()) + p.logger.Printf("Publisher %d received detached, closing", p.handleId.Load()) go p.Close(context.Background()) } func (p *mcuJanusPublisher) handleConnected(event *janus.WebRTCUpMsg) { - log.Printf("Publisher %d received connected", p.handleId.Load()) + p.logger.Printf("Publisher %d received connected", p.handleId.Load()) p.mcu.publisherConnected.Notify(string(getStreamId(p.id, p.streamType))) } func (p *mcuJanusPublisher) handleSlowLink(event *janus.SlowLinkMsg) { if event.Uplink { - log.Printf("Publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) + p.logger.Printf("Publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } else { - log.Printf("Publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) + p.logger.Printf("Publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } } @@ -124,21 +123,21 @@ func (p *mcuJanusPublisher) NotifyReconnected() { ctx := context.TODO() handle, session, roomId, _, err := p.mcu.getOrCreatePublisherHandle(ctx, p.id, p.streamType, p.settings) if err != nil { - log.Printf("Could not reconnect publisher %s: %s", p.id, err) + p.logger.Printf("Could not reconnect publisher %s: %s", p.id, err) // TODO(jojo): Retry return } if prev := p.handle.Swap(handle); prev != nil { if _, err := prev.Detach(context.Background()); err != nil { - log.Printf("Error detaching old publisher handle %d: %s", prev.Id, err) + p.logger.Printf("Error detaching old publisher handle %d: %s", prev.Id, err) } } p.handleId.Store(handle.Id) p.session = session p.roomId = roomId - log.Printf("Publisher %s reconnected on handle %d", p.id, p.handleId.Load()) + p.logger.Printf("Publisher %s reconnected on handle %d", p.id, p.handleId.Load()) } func (p *mcuJanusPublisher) Close(ctx context.Context) { @@ -150,9 +149,9 @@ func (p *mcuJanusPublisher) Close(ctx context.Context) { "room": p.roomId, } if _, err := handle.Request(ctx, destroy_msg); err != nil { - log.Printf("Error destroying room %d: %s", p.roomId, err) + p.logger.Printf("Error destroying room %d: %s", p.roomId, err) } else { - log.Printf("Room %d destroyed", p.roomId) + p.logger.Printf("Room %d destroyed", p.roomId) } p.mcu.mu.Lock() delete(p.mcu.publishers, getStreamId(p.id, p.streamType)) @@ -215,9 +214,9 @@ func (p *mcuJanusPublisher) SendMessage(ctx context.Context, message *MessageCli sdpString, found := api.GetStringMapEntry[string](jsep, "sdp") if !found { - log.Printf("No/invalid sdp found in answer %+v", jsep) + p.logger.Printf("No/invalid sdp found in answer %+v", jsep) } else if answerSdp, err := parseSDP(sdpString); err != nil { - log.Printf("Error parsing answer sdp %+v: %s", sdpString, err) + p.logger.Printf("Error parsing answer sdp %+v: %s", sdpString, err) p.answerSdp.Store(nil) p.sdpFlags.Remove(sdpHasAnswer) } else { @@ -355,7 +354,7 @@ func (p *mcuJanusPublisher) GetStreams(ctx context.Context) ([]PublisherStream, switch a.Key { case sdp.AttrKeyExtMap: if err := extmap.Unmarshal(extmap.Name() + ":" + a.Value); err != nil { - log.Printf("Error parsing extmap %s: %s", a.Value, err) + p.logger.Printf("Error parsing extmap %s: %s", a.Value, err) continue } @@ -388,7 +387,7 @@ func (p *mcuJanusPublisher) GetStreams(ctx context.Context) ([]PublisherStream, } else if strings.EqualFold(s.Type, "data") { // nolint // Already handled above. } else { - log.Printf("Skip type %s", s.Type) + p.logger.Printf("Skip type %s", s.Type) continue } @@ -423,7 +422,7 @@ func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId PublicSe } errorMessage := getPluginStringValue(response.PluginData, pluginVideoRoom, "error") - errorCode := getPluginIntValue(response.PluginData, pluginVideoRoom, "error_code") + errorCode := getPluginIntValue(p.logger, response.PluginData, pluginVideoRoom, "error_code") if errorMessage != "" || errorCode != 0 { if errorCode == 0 { errorCode = 500 @@ -440,7 +439,7 @@ func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId PublicSe } } - log.Printf("Publishing %s to %s (port=%d, rtcpPort=%d) for %s", p.id, hostname, port, rtcpPort, remoteId) + p.logger.Printf("Publishing %s to %s (port=%d, rtcpPort=%d) for %s", p.id, hostname, port, rtcpPort, remoteId) return nil } @@ -462,7 +461,7 @@ func (p *mcuJanusPublisher) UnpublishRemote(ctx context.Context, remoteId Public } errorMessage := getPluginStringValue(response.PluginData, pluginVideoRoom, "error") - errorCode := getPluginIntValue(response.PluginData, pluginVideoRoom, "error_code") + errorCode := getPluginIntValue(p.logger, response.PluginData, pluginVideoRoom, "error_code") if errorMessage != "" || errorCode != 0 { if errorCode == 0 { errorCode = 500 @@ -479,6 +478,6 @@ func (p *mcuJanusPublisher) UnpublishRemote(ctx context.Context, remoteId Public } } - log.Printf("Unpublished remote %s for %s", p.id, remoteId) + p.logger.Printf("Unpublished remote %s for %s", p.id, remoteId) return nil } diff --git a/mcu_janus_publisher_test.go b/mcu_janus_publisher_test.go index 98ae23b..31d3acf 100644 --- a/mcu_janus_publisher_test.go +++ b/mcu_janus_publisher_test.go @@ -102,7 +102,6 @@ func TestGetFmtpValueVP9(t *testing.T) { } func TestJanusPublisherRemote(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) diff --git a/mcu_janus_remote_publisher.go b/mcu_janus_remote_publisher.go index 078adc5..55ce873 100644 --- a/mcu_janus_remote_publisher.go +++ b/mcu_janus_remote_publisher.go @@ -23,7 +23,6 @@ package signaling import ( "context" - "log" "sync/atomic" "github.com/notedit/janus-go" @@ -63,38 +62,38 @@ func (p *mcuJanusRemotePublisher) handleEvent(event *janus.EventMsg) { ctx := context.TODO() switch videoroom { case "destroyed": - log.Printf("Remote publisher %d: associated room has been destroyed, closing", p.handleId.Load()) + p.logger.Printf("Remote publisher %d: associated room has been destroyed, closing", p.handleId.Load()) go p.Close(ctx) case "slow_link": // Ignore, processed through "handleSlowLink" in the general events. default: - log.Printf("Unsupported videoroom remote publisher event in %d: %+v", p.handleId.Load(), event) + p.logger.Printf("Unsupported videoroom remote publisher event in %d: %+v", p.handleId.Load(), event) } } else { - log.Printf("Unsupported remote publisher event in %d: %+v", p.handleId.Load(), event) + p.logger.Printf("Unsupported remote publisher event in %d: %+v", p.handleId.Load(), event) } } func (p *mcuJanusRemotePublisher) handleHangup(event *janus.HangupMsg) { - log.Printf("Remote publisher %d received hangup (%s), closing", p.handleId.Load(), event.Reason) + p.logger.Printf("Remote publisher %d received hangup (%s), closing", p.handleId.Load(), event.Reason) go p.Close(context.Background()) } func (p *mcuJanusRemotePublisher) handleDetached(event *janus.DetachedMsg) { - log.Printf("Remote publisher %d received detached, closing", p.handleId.Load()) + p.logger.Printf("Remote publisher %d received detached, closing", p.handleId.Load()) go p.Close(context.Background()) } func (p *mcuJanusRemotePublisher) handleConnected(event *janus.WebRTCUpMsg) { - log.Printf("Remote publisher %d received connected", p.handleId.Load()) + p.logger.Printf("Remote publisher %d received connected", p.handleId.Load()) p.mcu.publisherConnected.Notify(string(getStreamId(p.id, p.streamType))) } func (p *mcuJanusRemotePublisher) handleSlowLink(event *janus.SlowLinkMsg) { if event.Uplink { - log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) + p.logger.Printf("Remote publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } else { - log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) + p.logger.Printf("Remote publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } } @@ -102,21 +101,21 @@ func (p *mcuJanusRemotePublisher) NotifyReconnected() { ctx := context.TODO() handle, session, roomId, _, err := p.mcu.getOrCreatePublisherHandle(ctx, p.id, p.streamType, p.settings) if err != nil { - log.Printf("Could not reconnect remote publisher %s: %s", p.id, err) + p.logger.Printf("Could not reconnect remote publisher %s: %s", p.id, err) // TODO(jojo): Retry return } if prev := p.handle.Swap(handle); prev != nil { if _, err := prev.Detach(context.Background()); err != nil { - log.Printf("Error detaching old remote publisher handle %d: %s", prev.Id, err) + p.logger.Printf("Error detaching old remote publisher handle %d: %s", prev.Id, err) } } p.handleId.Store(handle.Id) p.session = session p.roomId = roomId - log.Printf("Remote publisher %s reconnected on handle %d", p.id, p.handleId.Load()) + p.logger.Printf("Remote publisher %s reconnected on handle %d", p.id, p.handleId.Load()) } func (p *mcuJanusRemotePublisher) Close(ctx context.Context) { @@ -125,7 +124,7 @@ func (p *mcuJanusRemotePublisher) Close(ctx context.Context) { } if err := p.controller.StopPublishing(ctx, p); err != nil { - log.Printf("Error stopping remote publisher %s in room %d: %s", p.id, p.roomId, err) + p.logger.Printf("Error stopping remote publisher %s in room %d: %s", p.id, p.roomId, err) } p.mu.Lock() @@ -138,9 +137,9 @@ func (p *mcuJanusRemotePublisher) Close(ctx context.Context) { "id": streamTypeUserIds[p.streamType], }) if err != nil { - log.Printf("Error removing remote publisher %s in room %d: %s", p.id, p.roomId, err) + p.logger.Printf("Error removing remote publisher %s in room %d: %s", p.id, p.roomId, err) } else { - log.Printf("Removed remote publisher: %+v", response) + p.logger.Printf("Removed remote publisher: %+v", response) } if p.roomId != 0 { destroy_msg := api.StringMap{ @@ -148,9 +147,9 @@ func (p *mcuJanusRemotePublisher) Close(ctx context.Context) { "room": p.roomId, } if _, err := handle.Request(ctx, destroy_msg); err != nil { - log.Printf("Error destroying room %d: %s", p.roomId, err) + p.logger.Printf("Error destroying room %d: %s", p.roomId, err) } else { - log.Printf("Room %d destroyed", p.roomId) + p.logger.Printf("Room %d destroyed", p.roomId) } p.mcu.mu.Lock() delete(p.mcu.remotePublishers, getStreamId(p.id, p.streamType)) diff --git a/mcu_janus_remote_subscriber.go b/mcu_janus_remote_subscriber.go index 356bb65..5d984b7 100644 --- a/mcu_janus_remote_subscriber.go +++ b/mcu_janus_remote_subscriber.go @@ -23,7 +23,6 @@ package signaling import ( "context" - "log" "strconv" "sync/atomic" @@ -41,7 +40,7 @@ func (p *mcuJanusRemoteSubscriber) handleEvent(event *janus.EventMsg) { ctx := context.TODO() switch videoroom { case "destroyed": - log.Printf("Remote subscriber %d: associated room has been destroyed, closing", p.handleId.Load()) + p.logger.Printf("Remote subscriber %d: associated room has been destroyed, closing", p.handleId.Load()) go p.Close(ctx) case "event": // Handle renegotiations, but ignore other events like selected @@ -53,33 +52,33 @@ func (p *mcuJanusRemoteSubscriber) handleEvent(event *janus.EventMsg) { case "slow_link": // Ignore, processed through "handleSlowLink" in the general events. default: - log.Printf("Unsupported videoroom event %s for remote subscriber %d: %+v", videoroom, p.handleId.Load(), event) + p.logger.Printf("Unsupported videoroom event %s for remote subscriber %d: %+v", videoroom, p.handleId.Load(), event) } } else { - log.Printf("Unsupported event for remote subscriber %d: %+v", p.handleId.Load(), event) + p.logger.Printf("Unsupported event for remote subscriber %d: %+v", p.handleId.Load(), event) } } func (p *mcuJanusRemoteSubscriber) handleHangup(event *janus.HangupMsg) { - log.Printf("Remote subscriber %d received hangup (%s), closing", p.handleId.Load(), event.Reason) + p.logger.Printf("Remote subscriber %d received hangup (%s), closing", p.handleId.Load(), event.Reason) go p.Close(context.Background()) } func (p *mcuJanusRemoteSubscriber) handleDetached(event *janus.DetachedMsg) { - log.Printf("Remote subscriber %d received detached, closing", p.handleId.Load()) + p.logger.Printf("Remote subscriber %d received detached, closing", p.handleId.Load()) go p.Close(context.Background()) } func (p *mcuJanusRemoteSubscriber) handleConnected(event *janus.WebRTCUpMsg) { - log.Printf("Remote subscriber %d received connected", p.handleId.Load()) + p.logger.Printf("Remote subscriber %d received connected", p.handleId.Load()) p.mcu.SubscriberConnected(p.Id(), p.publisher, p.streamType) } func (p *mcuJanusRemoteSubscriber) handleSlowLink(event *janus.SlowLinkMsg) { if event.Uplink { - log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) + p.logger.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } else { - log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) + p.logger.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } } @@ -93,21 +92,21 @@ func (p *mcuJanusRemoteSubscriber) NotifyReconnected() { handle, pub, err := p.mcu.getOrCreateSubscriberHandle(ctx, p.publisher, p.streamType) if err != nil { // TODO(jojo): Retry? - log.Printf("Could not reconnect remote subscriber for publisher %s: %s", p.publisher, err) + p.logger.Printf("Could not reconnect remote subscriber for publisher %s: %s", p.publisher, err) p.Close(context.Background()) return } if prev := p.handle.Swap(handle); prev != nil { if _, err := prev.Detach(context.Background()); err != nil { - log.Printf("Error detaching old remote subscriber handle %d: %s", prev.Id, err) + p.logger.Printf("Error detaching old remote subscriber handle %d: %s", prev.Id, err) } } p.handleId.Store(handle.Id) p.roomId = pub.roomId p.sid = strconv.FormatUint(handle.Id, 10) p.listener.SubscriberSidUpdated(p) - log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId.Load()) + p.logger.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId.Load()) } func (p *mcuJanusRemoteSubscriber) Close(ctx context.Context) { diff --git a/mcu_janus_subscriber.go b/mcu_janus_subscriber.go index dfb21fe..5fc1565 100644 --- a/mcu_janus_subscriber.go +++ b/mcu_janus_subscriber.go @@ -24,7 +24,6 @@ package signaling import ( "context" "fmt" - "log" "strconv" "github.com/notedit/janus-go" @@ -47,7 +46,7 @@ func (p *mcuJanusSubscriber) handleEvent(event *janus.EventMsg) { ctx := context.TODO() switch videoroom { case "destroyed": - log.Printf("Subscriber %d: associated room has been destroyed, closing", p.handleId.Load()) + p.logger.Printf("Subscriber %d: associated room has been destroyed, closing", p.handleId.Load()) go p.Close(ctx) case "updated": streams, ok := getPluginValue(event.Plugindata, pluginVideoRoom, "streams").([]any) @@ -64,48 +63,48 @@ func (p *mcuJanusSubscriber) handleEvent(event *janus.EventMsg) { } } - log.Printf("Subscriber %d: received updated event with no active media streams, closing", p.handleId.Load()) + p.logger.Printf("Subscriber %d: received updated event with no active media streams, closing", p.handleId.Load()) go p.Close(ctx) case "event": // Handle renegotiations, but ignore other events like selected // substream / temporal layer. if getPluginStringValue(event.Plugindata, pluginVideoRoom, "configured") == "ok" && event.Jsep != nil && event.Jsep["type"] == "offer" && event.Jsep["sdp"] != nil { - log.Printf("Subscriber %d: received updated offer", p.handleId.Load()) + p.logger.Printf("Subscriber %d: received updated offer", p.handleId.Load()) p.listener.OnUpdateOffer(p, event.Jsep) } else { - log.Printf("Subscriber %d: received unsupported event %+v", p.handleId.Load(), event) + p.logger.Printf("Subscriber %d: received unsupported event %+v", p.handleId.Load(), event) } case "slow_link": // Ignore, processed through "handleSlowLink" in the general events. default: - log.Printf("Unsupported videoroom event %s for subscriber %d: %+v", videoroom, p.handleId.Load(), event) + p.logger.Printf("Unsupported videoroom event %s for subscriber %d: %+v", videoroom, p.handleId.Load(), event) } } else { - log.Printf("Unsupported event for subscriber %d: %+v", p.handleId.Load(), event) + p.logger.Printf("Unsupported event for subscriber %d: %+v", p.handleId.Load(), event) } } func (p *mcuJanusSubscriber) handleHangup(event *janus.HangupMsg) { - log.Printf("Subscriber %d received hangup (%s), closing", p.handleId.Load(), event.Reason) + p.logger.Printf("Subscriber %d received hangup (%s), closing", p.handleId.Load(), event.Reason) go p.Close(context.Background()) } func (p *mcuJanusSubscriber) handleDetached(event *janus.DetachedMsg) { - log.Printf("Subscriber %d received detached, closing", p.handleId.Load()) + p.logger.Printf("Subscriber %d received detached, closing", p.handleId.Load()) go p.Close(context.Background()) } func (p *mcuJanusSubscriber) handleConnected(event *janus.WebRTCUpMsg) { - log.Printf("Subscriber %d received connected", p.handleId.Load()) + p.logger.Printf("Subscriber %d received connected", p.handleId.Load()) p.mcu.SubscriberConnected(p.Id(), p.publisher, p.streamType) } func (p *mcuJanusSubscriber) handleSlowLink(event *janus.SlowLinkMsg) { if event.Uplink { - log.Printf("Subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) + p.logger.Printf("Subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } else { - log.Printf("Subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) + p.logger.Printf("Subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } } @@ -119,21 +118,21 @@ func (p *mcuJanusSubscriber) NotifyReconnected() { handle, pub, err := p.mcu.getOrCreateSubscriberHandle(ctx, p.publisher, p.streamType) if err != nil { // TODO(jojo): Retry? - log.Printf("Could not reconnect subscriber for publisher %s: %s", p.publisher, err) + p.logger.Printf("Could not reconnect subscriber for publisher %s: %s", p.publisher, err) p.Close(context.Background()) return } if prev := p.handle.Swap(handle); prev != nil { if _, err := prev.Detach(context.Background()); err != nil { - log.Printf("Error detaching old subscriber handle %d: %s", prev.Id, err) + p.logger.Printf("Error detaching old subscriber handle %d: %s", prev.Id, err) } } p.handleId.Store(handle.Id) p.roomId = pub.roomId p.sid = strconv.FormatUint(handle.Id, 10) p.listener.SubscriberSidUpdated(p) - log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId.Load()) + p.logger.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId.Load()) } func (p *mcuJanusSubscriber) closeClient(ctx context.Context) bool { @@ -193,7 +192,7 @@ retry: return } - if error_code := getPluginIntValue(join_response.Plugindata, pluginVideoRoom, "error_code"); error_code > 0 { + if error_code := getPluginIntValue(p.logger, join_response.Plugindata, pluginVideoRoom, "error_code"); error_code > 0 { switch error_code { case JANUS_VIDEOROOM_ERROR_ALREADY_JOINED: // The subscriber is already connected to the room. This can happen @@ -219,7 +218,7 @@ retry: if prev := p.handle.Swap(handle); prev != nil { if _, err := prev.Detach(context.Background()); err != nil { - log.Printf("Error detaching old subscriber handle %d: %s", prev.Id, err) + p.logger.Printf("Error detaching old subscriber handle %d: %s", prev.Id, err) } } p.handleId.Store(handle.Id) @@ -229,19 +228,19 @@ retry: p.closeChan = make(chan struct{}, 1) statsSubscribersCurrent.WithLabelValues(string(p.streamType)).Inc() go p.run(handle, p.closeChan) - log.Printf("Already connected subscriber %d for %s, leaving and re-joining on handle %d", p.id, p.streamType, p.handleId.Load()) + p.logger.Printf("Already connected subscriber %d for %s, leaving and re-joining on handle %d", p.id, p.streamType, p.handleId.Load()) goto retry case JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM: fallthrough case JANUS_VIDEOROOM_ERROR_NO_SUCH_FEED: switch error_code { case JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM: - log.Printf("Publisher %s not created yet for %s, not joining room %d as subscriber", p.publisher, p.streamType, p.roomId) + p.logger.Printf("Publisher %s not created yet for %s, not joining room %d as subscriber", p.publisher, p.streamType, p.roomId) go p.Close(context.Background()) callback(fmt.Errorf("Publisher %s not created yet for %s", p.publisher, p.streamType), nil) return case JANUS_VIDEOROOM_ERROR_NO_SUCH_FEED: - log.Printf("Publisher %s not sending yet for %s, wait and retry to join room %d as subscriber", p.publisher, p.streamType, p.roomId) + p.logger.Printf("Publisher %s not sending yet for %s, wait and retry to join room %d as subscriber", p.publisher, p.streamType, p.roomId) } if !loggedNotPublishingYet { @@ -254,7 +253,7 @@ retry: callback(err, nil) return } - log.Printf("Retry subscribing %s from %s", p.streamType, p.publisher) + p.logger.Printf("Retry subscribing %s from %s", p.streamType, p.publisher) goto retry default: // TODO(jojo): Should we handle other errors, too? @@ -262,7 +261,7 @@ retry: return } } - //log.Println("Joined as listener", join_response) + //p.logger.Println("Joined as listener", join_response) p.session = join_response.Session callback(nil, join_response.Jsep) diff --git a/mcu_janus_test.go b/mcu_janus_test.go index a818f35..99eab18 100644 --- a/mcu_janus_test.go +++ b/mcu_janus_test.go @@ -590,7 +590,9 @@ func newMcuJanusForTesting(t *testing.T) (*mcuJanus, *TestJanusGateway) { if strings.Contains(t.Name(), "Filter") { config.AddOption("mcu", "blockedcandidates", "192.0.0.0/24, 192.168.0.0/16") } - mcu, err := NewMcuJanus(context.Background(), "", config) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + mcu, err := NewMcuJanus(ctx, "", config) require.NoError(t, err) t.Cleanup(func() { mcu.Stop() @@ -600,7 +602,7 @@ func newMcuJanusForTesting(t *testing.T) (*mcuJanus, *TestJanusGateway) { mcuJanus.createJanusGateway = func(ctx context.Context, wsURL string, listener GatewayListener) (JanusGatewayInterface, error) { return gateway, nil } - require.NoError(t, mcu.Start(context.Background())) + require.NoError(t, mcu.Start(ctx)) return mcuJanus, gateway } @@ -675,7 +677,6 @@ func (i *TestMcuInitiator) Country() string { } func Test_JanusPublisherFilterOffer(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -786,7 +787,6 @@ func Test_JanusPublisherFilterOffer(t *testing.T) { } func Test_JanusSubscriberFilterAnswer(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -908,7 +908,6 @@ func Test_JanusSubscriberFilterAnswer(t *testing.T) { } func Test_JanusPublisherGetStreamsAudioOnly(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -992,7 +991,6 @@ func Test_JanusPublisherGetStreamsAudioOnly(t *testing.T) { } func Test_JanusPublisherGetStreamsAudioVideo(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -1086,7 +1084,6 @@ func Test_JanusPublisherSubscriber(t *testing.T) { ResetStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("incoming")) ResetStatsValue(t, statsJanusBandwidthCurrent.WithLabelValues("outgoing")) - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) @@ -1166,7 +1163,6 @@ func Test_JanusPublisherSubscriber(t *testing.T) { } func Test_JanusSubscriberPublisher(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) @@ -1215,7 +1211,6 @@ func Test_JanusSubscriberPublisher(t *testing.T) { } func Test_JanusSubscriberRequestOffer(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -1317,7 +1312,6 @@ func Test_JanusSubscriberRequestOffer(t *testing.T) { } func Test_JanusRemotePublisher(t *testing.T) { - CatchLogForTest(t) t.Parallel() assert := assert.New(t) require := require.New(t) @@ -1409,7 +1403,6 @@ func Test_JanusSubscriberNoSuchRoom(t *testing.T) { } }) - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) @@ -1509,7 +1502,6 @@ func test_JanusSubscriberAlreadyJoined(t *testing.T) { } }) - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) @@ -1619,7 +1611,6 @@ func Test_JanusSubscriberTimeout(t *testing.T) { } }) - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) @@ -1723,7 +1714,6 @@ func Test_JanusSubscriberCloseEmptyStreams(t *testing.T) { } }) - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) @@ -1834,7 +1824,6 @@ func Test_JanusSubscriberRoomDestroyed(t *testing.T) { } }) - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) @@ -1945,7 +1934,6 @@ func Test_JanusSubscriberUpdateOffer(t *testing.T) { } }) - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) diff --git a/mcu_proxy.go b/mcu_proxy.go index 85ce622..37ee0ff 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -29,7 +29,6 @@ import ( "errors" "fmt" "iter" - "log" "math/rand/v2" "net" "net/http" @@ -80,6 +79,8 @@ type McuProxy interface { } type mcuProxyPubSubCommon struct { + logger Logger + sid string streamType StreamType maxBitrate api.Bandwidth @@ -112,7 +113,7 @@ func (c *mcuProxyPubSubCommon) doSendMessage(ctx context.Context, msg *ProxyClie } if proxyDebugMessages { - log.Printf("Response from %s: %+v", c.conn, response) + c.logger.Printf("Response from %s: %+v", c.conn, response) } if response.Type == "error" { callback(response.Error, nil) @@ -129,7 +130,7 @@ func (c *mcuProxyPubSubCommon) doProcessPayload(client McuClient, msg *PayloadPr case "offer": offer, ok := api.ConvertStringMap(msg.Payload["offer"]) if !ok { - log.Printf("Unsupported payload from %s: %+v", c.conn, msg) + c.logger.Printf("Unsupported payload from %s: %+v", c.conn, msg) return } @@ -137,7 +138,7 @@ func (c *mcuProxyPubSubCommon) doProcessPayload(client McuClient, msg *PayloadPr case "candidate": c.listener.OnIceCandidate(client, msg.Payload["candidate"]) default: - log.Printf("Unsupported payload from %s: %+v", c.conn, msg) + c.logger.Printf("Unsupported payload from %s: %+v", c.conn, msg) } } @@ -148,9 +149,11 @@ type mcuProxyPublisher struct { settings NewPublisherSettings } -func newMcuProxyPublisher(id PublicSessionId, sid string, streamType StreamType, maxBitrate api.Bandwidth, settings NewPublisherSettings, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher { +func newMcuProxyPublisher(logger Logger, id PublicSessionId, sid string, streamType StreamType, maxBitrate api.Bandwidth, settings NewPublisherSettings, proxyId string, conn *mcuProxyConnection, listener McuListener) *mcuProxyPublisher { return &mcuProxyPublisher{ mcuProxyPubSubCommon: mcuProxyPubSubCommon{ + logger: logger, + sid: sid, streamType: streamType, maxBitrate: maxBitrate, @@ -177,7 +180,7 @@ func (p *mcuProxyPublisher) SetMedia(mt MediaType) { } func (p *mcuProxyPublisher) NotifyClosed() { - log.Printf("Publisher %s at %s was closed", p.proxyId, p.conn) + p.logger.Printf("Publisher %s at %s was closed", p.proxyId, p.conn) p.listener.PublisherClosed(p) p.conn.removePublisher(p) } @@ -194,14 +197,14 @@ func (p *mcuProxyPublisher) Close(ctx context.Context) { } if response, _, err := p.conn.performSyncRequest(ctx, msg); err != nil { - log.Printf("Could not delete publisher %s at %s: %s", p.proxyId, p.conn, err) + p.logger.Printf("Could not delete publisher %s at %s: %s", p.proxyId, p.conn, err) return } else if response.Type == "error" { - log.Printf("Could not delete publisher %s at %s: %s", p.proxyId, p.conn, response.Error) + p.logger.Printf("Could not delete publisher %s at %s: %s", p.proxyId, p.conn, response.Error) return } - log.Printf("Deleted publisher %s at %s", p.proxyId, p.conn) + p.logger.Printf("Deleted publisher %s at %s", p.proxyId, p.conn) } func (p *mcuProxyPublisher) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, api.StringMap)) { @@ -229,7 +232,7 @@ func (p *mcuProxyPublisher) ProcessEvent(msg *EventProxyServerMessage) { case "publisher-closed": p.NotifyClosed() default: - log.Printf("Unsupported event from %s: %+v", p.conn, msg) + p.logger.Printf("Unsupported event from %s: %+v", p.conn, msg) } } @@ -240,9 +243,11 @@ type mcuProxySubscriber struct { publisherConn *mcuProxyConnection } -func newMcuProxySubscriber(publisherId PublicSessionId, sid string, streamType StreamType, maxBitrate api.Bandwidth, proxyId string, conn *mcuProxyConnection, listener McuListener, publisherConn *mcuProxyConnection) *mcuProxySubscriber { +func newMcuProxySubscriber(logger Logger, publisherId PublicSessionId, sid string, streamType StreamType, maxBitrate api.Bandwidth, proxyId string, conn *mcuProxyConnection, listener McuListener, publisherConn *mcuProxyConnection) *mcuProxySubscriber { return &mcuProxySubscriber{ mcuProxyPubSubCommon: mcuProxyPubSubCommon{ + logger: logger, + sid: sid, streamType: streamType, maxBitrate: maxBitrate, @@ -262,9 +267,9 @@ func (s *mcuProxySubscriber) Publisher() PublicSessionId { func (s *mcuProxySubscriber) NotifyClosed() { if s.publisherConn != nil { - log.Printf("Remote subscriber %s at %s (forwarded to %s) was closed", s.proxyId, s.conn, s.publisherConn) + s.logger.Printf("Remote subscriber %s at %s (forwarded to %s) was closed", s.proxyId, s.conn, s.publisherConn) } else { - log.Printf("Subscriber %s at %s was closed", s.proxyId, s.conn) + s.logger.Printf("Subscriber %s at %s was closed", s.proxyId, s.conn) } s.listener.SubscriberClosed(s) s.conn.removeSubscriber(s) @@ -283,24 +288,24 @@ func (s *mcuProxySubscriber) Close(ctx context.Context) { if response, _, err := s.conn.performSyncRequest(ctx, msg); err != nil { if s.publisherConn != nil { - log.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", s.proxyId, s.conn, s.publisherConn, err) + s.logger.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", s.proxyId, s.conn, s.publisherConn, err) } else { - log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, err) + s.logger.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, err) } return } else if response.Type == "error" { if s.publisherConn != nil { - log.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", s.proxyId, s.conn, s.publisherConn, response.Error) + s.logger.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", s.proxyId, s.conn, s.publisherConn, response.Error) } else { - log.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, response.Error) + s.logger.Printf("Could not delete subscriber %s at %s: %s", s.proxyId, s.conn, response.Error) } return } if s.publisherConn != nil { - log.Printf("Deleted remote subscriber %s at %s (forwarded to %s)", s.proxyId, s.conn, s.publisherConn) + s.logger.Printf("Deleted remote subscriber %s at %s (forwarded to %s)", s.proxyId, s.conn, s.publisherConn) } else { - log.Printf("Deleted subscriber %s at %s", s.proxyId, s.conn) + s.logger.Printf("Deleted subscriber %s at %s", s.proxyId, s.conn) } } @@ -332,13 +337,14 @@ func (s *mcuProxySubscriber) ProcessEvent(msg *EventProxyServerMessage) { case "subscriber-closed": s.NotifyClosed() default: - log.Printf("Unsupported event from %s: %+v", s.conn, msg) + s.logger.Printf("Unsupported event from %s: %+v", s.conn, msg) } } type mcuProxyCallback func(response *ProxyServerMessage) type mcuProxyConnection struct { + logger Logger proxy *mcuProxy rawUrl string url *url.URL @@ -395,6 +401,7 @@ func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP, token str } conn := &mcuProxyConnection{ + logger: proxy.logger, proxy: proxy, rawUrl: baseUrl, url: parsed, @@ -598,7 +605,7 @@ func (c *mcuProxyConnection) readPump() { rtt := now.Sub(time.Unix(0, ts)) if rtt >= rttLogDuration { rtt_ms := rtt.Nanoseconds() / time.Millisecond.Nanoseconds() - log.Printf("Proxy at %s has RTT of %d ms (%s)", c, rtt_ms, rtt) + c.logger.Printf("Proxy at %s has RTT of %d ms (%s)", c, rtt_ms, rtt) } } return nil @@ -614,14 +621,14 @@ func (c *mcuProxyConnection) readPump() { websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - log.Printf("Error reading from %s: %v", c, err) + c.logger.Printf("Error reading from %s: %v", c, err) } break } var msg ProxyServerMessage if err := json.Unmarshal(message, &msg); err != nil { - log.Printf("Error unmarshaling %s from %s: %s", string(message), c, err) + c.logger.Printf("Error unmarshaling %s from %s: %s", string(message), c, err) continue } @@ -640,7 +647,7 @@ func (c *mcuProxyConnection) sendPing() bool { msg := strconv.FormatInt(now.UnixNano(), 10) c.conn.SetWriteDeadline(now.Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil { - log.Printf("Could not send ping to proxy at %s: %v", c, err) + c.logger.Printf("Could not send ping to proxy at %s: %v", c, err) go c.scheduleReconnect() return false } @@ -692,7 +699,7 @@ func (c *mcuProxyConnection) stop(ctx context.Context) { c.closer.Close() if err := c.sendClose(); err != nil { if err != ErrNotConnected { - log.Printf("Could not send close message to %s: %s", c, err) + c.logger.Printf("Could not send close message to %s: %s", c, err) } c.close() return @@ -702,7 +709,7 @@ func (c *mcuProxyConnection) stop(ctx context.Context) { case <-c.closedDone.C: case <-ctx.Done(): if err := ctx.Err(); err != nil { - log.Printf("Error waiting for connection to %s get closed: %s", c, err) + c.logger.Printf("Error waiting for connection to %s get closed: %s", c, err) c.close() } } @@ -742,7 +749,7 @@ func (c *mcuProxyConnection) closeIfEmpty() bool { c.subscribersLock.RUnlock() if total > 0 { // Connection will be closed once all clients have disconnected. - log.Printf("Connection to %s is still used by %d clients, defer closing", c, total) + c.logger.Printf("Connection to %s is still used by %d clients, defer closing", c, total) return false } @@ -750,7 +757,7 @@ func (c *mcuProxyConnection) closeIfEmpty() bool { ctx, cancel := context.WithTimeout(context.Background(), closeTimeout) defer cancel() - log.Printf("All clients disconnected, closing connection to %s", c) + c.logger.Printf("All clients disconnected, closing connection to %s", c) c.stop(ctx) statsProxyBackendLoadCurrent.DeleteLabelValues(c.url.String()) @@ -766,7 +773,7 @@ func (c *mcuProxyConnection) closeIfEmpty() bool { func (c *mcuProxyConnection) scheduleReconnect() { if err := c.sendClose(); err != nil && err != ErrNotConnected { - log.Printf("Could not send close message to %s: %s", c, err) + c.logger.Printf("Could not send close message to %s: %s", c, err) } c.close() @@ -783,7 +790,7 @@ func (c *mcuProxyConnection) scheduleReconnect() { func (c *mcuProxyConnection) reconnect() { u, err := c.url.Parse("proxy") if err != nil { - log.Printf("Could not resolve url to proxy at %s: %s", c, err) + c.logger.Printf("Could not resolve url to proxy at %s: %s", c, err) c.scheduleReconnect() return } @@ -813,12 +820,12 @@ func (c *mcuProxyConnection) reconnect() { } conn, _, err := dialer.Dial(u.String(), nil) if err != nil { - log.Printf("Could not connect to %s: %s", c, err) + c.logger.Printf("Could not connect to %s: %s", c, err) c.scheduleReconnect() return } - log.Printf("Connected to %s", c) + c.logger.Printf("Connected to %s", c) c.closed.Store(false) c.helloProcessed.Store(false) c.connectedSince.Store(time.Now().UnixMicro()) @@ -830,7 +837,7 @@ func (c *mcuProxyConnection) reconnect() { c.reconnectInterval.Store(int64(initialReconnectInterval)) c.shutdownScheduled.Store(false) if err := c.sendHello(); err != nil { - log.Printf("Could not send hello request to %s: %s", c, err) + c.logger.Printf("Could not send hello request to %s: %s", c, err) c.scheduleReconnect() return } @@ -976,19 +983,19 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { switch msg.Type { case "error": if msg.Error.Code == "no_such_session" { - log.Printf("Session %s could not be resumed on %s, registering new", c.SessionId(), c) + c.logger.Printf("Session %s could not be resumed on %s, registering new", c.SessionId(), c) c.clearPublishers() c.clearSubscribers() c.clearCallbacks() c.sessionId.Store(PublicSessionId("")) if err := c.sendHello(); err != nil { - log.Printf("Could not send hello request to %s: %s", c, err) + c.logger.Printf("Could not send hello request to %s: %s", c, err) c.scheduleReconnect() } return } - log.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error) + c.logger.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error) c.scheduleReconnect() case "hello": resumed := c.SessionId() == msg.Hello.SessionId @@ -996,7 +1003,7 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { country := "" if server := msg.Hello.Server; server != nil { if country = server.Country; country != "" && !IsValidCountry(country) { - log.Printf("Proxy %s sent invalid country %s in hello response", c, country) + c.logger.Printf("Proxy %s sent invalid country %s in hello response", c, country) country = "" } c.version.Store(server.Version) @@ -1010,11 +1017,11 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { } c.country.Store(country) if resumed { - log.Printf("Resumed session %s on %s", c.SessionId(), c) + c.logger.Printf("Resumed session %s on %s", c.SessionId(), c) } else if country != "" { - log.Printf("Received session %s from %s (in %s)", c.SessionId(), c, country) + c.logger.Printf("Received session %s from %s (in %s)", c.SessionId(), c, country) } else { - log.Printf("Received session %s from %s", c.SessionId(), c) + c.logger.Printf("Received session %s from %s", c.SessionId(), c) } if c.trackClose.CompareAndSwap(false, true) { statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Inc() @@ -1023,14 +1030,14 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { c.helloProcessed.Store(true) c.connectedNotifier.Notify() default: - log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c) + c.logger.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c) c.scheduleReconnect() } return } if proxyDebugMessages { - log.Printf("Received from %s: %+v", c, msg) + c.logger.Printf("Received from %s: %+v", c, msg) } if callback := c.getCallback(msg.Id); callback != nil { callback(msg) @@ -1048,7 +1055,7 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { case "bye": c.processBye(msg) default: - log.Printf("Unsupported message received from %s: %+v", c, msg) + c.logger.Printf("Unsupported message received from %s: %+v", c, msg) } } @@ -1070,25 +1077,25 @@ func (c *mcuProxyConnection) processPayload(msg *ProxyServerMessage) { return } - log.Printf("Received payload for unknown client %+v from %s", payload, c) + c.logger.Printf("Received payload for unknown client %+v from %s", payload, c) } func (c *mcuProxyConnection) processEvent(msg *ProxyServerMessage) { event := msg.Event switch event.Type { case "backend-disconnected": - log.Printf("Upstream backend at %s got disconnected, reset MCU objects", c) + c.logger.Printf("Upstream backend at %s got disconnected, reset MCU objects", c) c.clearPublishers() c.clearSubscribers() c.clearCallbacks() // TODO: Should we also reconnect? return case "backend-connected": - log.Printf("Upstream backend at %s is connected", c) + c.logger.Printf("Upstream backend at %s is connected", c) return case "update-load": if proxyDebugMessages { - log.Printf("Load of %s now at %d (%s)", c, event.Load, event.Bandwidth) + c.logger.Printf("Load of %s now at %d (%s)", c, event.Load, event.Bandwidth) } c.load.Store(event.Load) c.bandwidth.Store(event.Bandwidth) @@ -1109,13 +1116,13 @@ func (c *mcuProxyConnection) processEvent(msg *ProxyServerMessage) { } return case "shutdown-scheduled": - log.Printf("Proxy %s is scheduled to shutdown", c) + c.logger.Printf("Proxy %s is scheduled to shutdown", c) c.shutdownScheduled.Store(true) return } if proxyDebugMessages { - log.Printf("Process event from %s: %+v", c, event) + c.logger.Printf("Process event from %s: %+v", c, event) } c.publishersLock.RLock() publisher, found := c.publishers[event.ClientId] @@ -1133,20 +1140,20 @@ func (c *mcuProxyConnection) processEvent(msg *ProxyServerMessage) { return } - log.Printf("Received event for unknown client %+v from %s", event, c) + c.logger.Printf("Received event for unknown client %+v from %s", event, c) } func (c *mcuProxyConnection) processBye(msg *ProxyServerMessage) { bye := msg.Bye switch bye.Reason { case "session_resumed": - log.Printf("Session %s on %s was resumed by other client, resetting", c.SessionId(), c) + c.logger.Printf("Session %s on %s was resumed by other client, resetting", c.SessionId(), c) case "session_expired": - log.Printf("Session %s expired on %s, resetting", c.SessionId(), c) + c.logger.Printf("Session %s expired on %s, resetting", c.SessionId(), c) case "session_closed": - log.Printf("Session %s was closed on %s, resetting", c.SessionId(), c) + c.logger.Printf("Session %s was closed on %s, resetting", c.SessionId(), c) default: - log.Printf("Received bye with unsupported reason from %s %+v", c, bye) + c.logger.Printf("Received bye with unsupported reason from %s %+v", c, bye) } c.sessionId.Store(PublicSessionId("")) } @@ -1185,7 +1192,7 @@ func (c *mcuProxyConnection) sendMessage(msg *ProxyClientMessage) error { // +checklocks:c.mu func (c *mcuProxyConnection) sendMessageLocked(msg *ProxyClientMessage) error { if proxyDebugMessages { - log.Printf("Send message to %s: %+v", c, msg) + c.logger.Printf("Send message to %s: %+v", c, msg) } if c.conn == nil { return ErrNotConnected @@ -1246,12 +1253,12 @@ func (c *mcuProxyConnection) performSyncRequest(ctx context.Context, msg *ProxyC func (c *mcuProxyConnection) deferredDeletePublisher(id PublicSessionId, streamType StreamType, response *ProxyServerMessage) { if response.Type == "error" { - log.Printf("Publisher for %s was not created at %s: %s", id, c, response.Error) + c.logger.Printf("Publisher for %s was not created at %s: %s", id, c, response.Error) return } proxyId := response.Command.Id - log.Printf("Created unused %s publisher %s on %s for %s", streamType, proxyId, c, id) + c.logger.Printf("Created unused %s publisher %s on %s for %s", streamType, proxyId, c, id) msg := &ProxyClientMessage{ Type: "command", Command: &CommandProxyClientMessage{ @@ -1264,14 +1271,14 @@ func (c *mcuProxyConnection) deferredDeletePublisher(id PublicSessionId, streamT defer cancel() if response, _, err := c.performSyncRequest(ctx, msg); err != nil { - log.Printf("Could not delete publisher %s at %s: %s", proxyId, c, err) + c.logger.Printf("Could not delete publisher %s at %s: %s", proxyId, c, err) return } else if response.Type == "error" { - log.Printf("Could not delete publisher %s at %s: %s", proxyId, c, response.Error) + c.logger.Printf("Could not delete publisher %s at %s: %s", proxyId, c, response.Error) return } - log.Printf("Deleted publisher %s at %s", proxyId, c) + c.logger.Printf("Deleted publisher %s at %s", proxyId, c) } func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListener, id PublicSessionId, sid string, streamType StreamType, settings NewPublisherSettings) (McuPublisher, error) { @@ -1301,8 +1308,8 @@ func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListe } proxyId := response.Command.Id - log.Printf("Created %s publisher %s on %s for %s", streamType, proxyId, c, id) - publisher := newMcuProxyPublisher(id, sid, streamType, response.Command.Bitrate, settings, proxyId, c, listener) + c.logger.Printf("Created %s publisher %s on %s for %s", streamType, proxyId, c, id) + publisher := newMcuProxyPublisher(c.logger, id, sid, streamType, response.Command.Bitrate, settings, proxyId, c, listener) c.publishersLock.Lock() c.publishers[proxyId] = publisher c.publisherIds[getStreamId(id, streamType)] = PublicSessionId(proxyId) @@ -1314,12 +1321,12 @@ func (c *mcuProxyConnection) newPublisher(ctx context.Context, listener McuListe func (c *mcuProxyConnection) deferredDeleteSubscriber(publisherSessionId PublicSessionId, streamType StreamType, publisherConn *mcuProxyConnection, response *ProxyServerMessage) { if response.Type == "error" { - log.Printf("Subscriber for %s was not created at %s: %s", publisherSessionId, c, response.Error) + c.logger.Printf("Subscriber for %s was not created at %s: %s", publisherSessionId, c, response.Error) return } proxyId := response.Command.Id - log.Printf("Created unused %s subscriber %s on %s for %s", streamType, proxyId, c, publisherSessionId) + c.logger.Printf("Created unused %s subscriber %s on %s for %s", streamType, proxyId, c, publisherSessionId) msg := &ProxyClientMessage{ Type: "command", @@ -1334,24 +1341,24 @@ func (c *mcuProxyConnection) deferredDeleteSubscriber(publisherSessionId PublicS if response, _, err := c.performSyncRequest(ctx, msg); err != nil { if publisherConn != nil { - log.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", proxyId, c, publisherConn, err) + c.logger.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", proxyId, c, publisherConn, err) } else { - log.Printf("Could not delete subscriber %s at %s: %s", proxyId, c, err) + c.logger.Printf("Could not delete subscriber %s at %s: %s", proxyId, c, err) } return } else if response.Type == "error" { if publisherConn != nil { - log.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", proxyId, c, publisherConn, response.Error) + c.logger.Printf("Could not delete remote subscriber %s at %s (forwarded to %s): %s", proxyId, c, publisherConn, response.Error) } else { - log.Printf("Could not delete subscriber %s at %s: %s", proxyId, c, response.Error) + c.logger.Printf("Could not delete subscriber %s at %s: %s", proxyId, c, response.Error) } return } if publisherConn != nil { - log.Printf("Deleted remote subscriber %s at %s (forwarded to %s)", proxyId, c, publisherConn) + c.logger.Printf("Deleted remote subscriber %s at %s (forwarded to %s)", proxyId, c, publisherConn) } else { - log.Printf("Deleted subscriber %s at %s", proxyId, c) + c.logger.Printf("Deleted subscriber %s at %s", proxyId, c) } } @@ -1378,8 +1385,8 @@ func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuList } proxyId := response.Command.Id - log.Printf("Created %s subscriber %s on %s for %s", streamType, proxyId, c, publisherSessionId) - subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener, nil) + c.logger.Printf("Created %s subscriber %s on %s for %s", streamType, proxyId, c, publisherSessionId) + subscriber := newMcuProxySubscriber(c.logger, publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener, nil) c.subscribersLock.Lock() c.subscribers[proxyId] = subscriber c.subscribersLock.Unlock() @@ -1425,8 +1432,8 @@ func (c *mcuProxyConnection) newRemoteSubscriber(ctx context.Context, listener M } proxyId := response.Command.Id - log.Printf("Created remote %s subscriber %s on %s for %s (forwarded to %s)", streamType, proxyId, c, publisherSessionId, publisherConn) - subscriber := newMcuProxySubscriber(publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener, publisherConn) + c.logger.Printf("Created remote %s subscriber %s on %s for %s (forwarded to %s)", streamType, proxyId, c, publisherSessionId, publisherConn) + subscriber := newMcuProxySubscriber(c.logger, publisherSessionId, response.Command.Sid, streamType, response.Command.Bitrate, proxyId, c, listener, publisherConn) c.subscribersLock.Lock() c.subscribers[proxyId] = subscriber c.subscribersLock.Unlock() @@ -1439,8 +1446,12 @@ type mcuProxySettings struct { mcuCommonSettings } -func newMcuProxySettings(config *goconf.ConfigFile) (McuSettings, error) { - settings := &mcuProxySettings{} +func newMcuProxySettings(ctx context.Context, config *goconf.ConfigFile) (McuSettings, error) { + settings := &mcuProxySettings{ + mcuCommonSettings: mcuCommonSettings{ + logger: LoggerFromContext(ctx), + }, + } if err := settings.load(config); err != nil { return nil, err } @@ -1458,18 +1469,19 @@ func (s *mcuProxySettings) load(config *goconf.ConfigFile) error { proxyTimeoutSeconds = defaultProxyTimeoutSeconds } proxyTimeout := time.Duration(proxyTimeoutSeconds) * time.Second - log.Printf("Using a timeout of %s for proxy requests", proxyTimeout) + s.logger.Printf("Using a timeout of %s for proxy requests", proxyTimeout) s.setTimeout(proxyTimeout) return nil } func (s *mcuProxySettings) Reload(config *goconf.ConfigFile) { if err := s.load(config); err != nil { - log.Printf("Error reloading proxy settings: %s", err) + s.logger.Printf("Error reloading proxy settings: %s", err) } } type mcuProxy struct { + logger Logger urlType string tokenId string tokenKey *rsa.PrivateKey @@ -1497,7 +1509,8 @@ type mcuProxy struct { rpcClients *GrpcClients } -func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients *GrpcClients, dnsMonitor *DnsMonitor) (Mcu, error) { +func NewMcuProxy(ctx context.Context, config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients *GrpcClients, dnsMonitor *DnsMonitor) (Mcu, error) { + logger := LoggerFromContext(ctx) urlType, _ := config.GetString("mcu", "urltype") if urlType == "" { urlType = proxyUrlTypeStatic @@ -1520,12 +1533,13 @@ func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients * return nil, fmt.Errorf("could not parse private key from %s: %s", tokenKeyFilename, err) } - settings, err := newMcuProxySettings((config)) + settings, err := newMcuProxySettings(ctx, config) if err != nil { return nil, err } mcu := &mcuProxy{ + logger: logger, urlType: urlType, tokenId: tokenId, tokenKey: tokenKey, @@ -1548,7 +1562,7 @@ func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients * skipverify, _ := config.GetBool("mcu", "skipverify") if skipverify { - log.Println("WARNING: MCU verification is disabled!") + logger.Println("WARNING: MCU verification is disabled!") mcu.dialer.TLSClientConfig = &tls.Config{ InsecureSkipVerify: skipverify, } @@ -1556,9 +1570,9 @@ func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients * switch urlType { case proxyUrlTypeStatic: - mcu.config, err = NewProxyConfigStatic(config, mcu, dnsMonitor) + mcu.config, err = NewProxyConfigStatic(logger, config, mcu, dnsMonitor) case proxyUrlTypeEtcd: - mcu.config, err = NewProxyConfigEtcd(config, etcdClient, mcu) + mcu.config, err = NewProxyConfigEtcd(logger, config, etcdClient, mcu) default: err = fmt.Errorf("unsupported proxy URL type %s", urlType) } @@ -1588,7 +1602,7 @@ func (m *mcuProxy) loadContinentsMap(config *goconf.ConfigFile) error { for option, value := range options { option = strings.ToUpper(strings.TrimSpace(option)) if !IsValidContinent(option) { - log.Printf("Ignore unknown continent %s", option) + m.logger.Printf("Ignore unknown continent %s", option) continue } @@ -1596,18 +1610,18 @@ func (m *mcuProxy) loadContinentsMap(config *goconf.ConfigFile) error { for v := range SplitEntries(value, ",") { v = strings.ToUpper(v) if !IsValidContinent(v) { - log.Printf("Ignore unknown continent %s for override %s", v, option) + m.logger.Printf("Ignore unknown continent %s for override %s", v, option) continue } values = append(values, v) } if len(values) == 0 { - log.Printf("No valid values found for continent override %s, ignoring", option) + m.logger.Printf("No valid values found for continent override %s, ignoring", option) continue } continentsMap[option] = values - log.Printf("Mapping users on continent %s to %s", option, values) + m.logger.Printf("Mapping users on continent %s to %s", option, values) } m.setContinentsMap(continentsMap) @@ -1701,7 +1715,7 @@ func (m *mcuProxy) AddConnection(ignoreErrors bool, url string, ips ...net.IP) e conn, err := newMcuProxyConnection(m, url, nil, "") if err != nil { if ignoreErrors { - log.Printf("Could not create proxy connection to %s: %s", url, err) + m.logger.Printf("Could not create proxy connection to %s: %s", url, err) return nil } @@ -1714,7 +1728,7 @@ func (m *mcuProxy) AddConnection(ignoreErrors bool, url string, ips ...net.IP) e conn, err := newMcuProxyConnection(m, url, ip, "") if err != nil { if ignoreErrors { - log.Printf("Could not create proxy connection to %s (%s): %s", url, ip, err) + m.logger.Printf("Could not create proxy connection to %s (%s): %s", url, ip, err) continue } @@ -1726,7 +1740,7 @@ func (m *mcuProxy) AddConnection(ignoreErrors bool, url string, ips ...net.IP) e } for _, conn := range conns { - log.Printf("Adding new connection to %s", conn) + m.logger.Printf("Adding new connection to %s", conn) conn.start() m.connections = append(m.connections, conn) @@ -1773,7 +1787,7 @@ func (m *mcuProxy) iterateConnections(url string, ips []net.IP) iter.Seq[*mcuPro func (m *mcuProxy) RemoveConnection(url string, ips ...net.IP) { for conn := range m.iterateConnections(url, ips) { - log.Printf("Removing connection to %s", conn) + m.logger.Printf("Removing connection to %s", conn) conn.closeIfEmpty() } } @@ -1793,11 +1807,11 @@ func (m *mcuProxy) Reload(config *goconf.ConfigFile) { } if err := m.loadContinentsMap(config); err != nil { - log.Printf("Error loading continents map: %s", err) + m.logger.Printf("Error loading continents map: %s", err) } if err := m.config.Reload(config); err != nil { - log.Printf("could not reload proxy configuration: %s", err) + m.logger.Printf("could not reload proxy configuration: %s", err) } } @@ -2033,7 +2047,7 @@ func (m *mcuProxy) createPublisher(ctx context.Context, listener McuListener, id publisher, err := conn.newPublisher(subctx, listener, id, sid, streamType, publisherSettings) if err != nil { - log.Printf("Could not create %s publisher for %s on %s: %s", streamType, id, conn, err) + m.logger.Printf("Could not create %s publisher for %s on %s: %s", streamType, id, conn, err) continue } @@ -2160,7 +2174,7 @@ func (m *mcuProxy) createSubscriber(ctx context.Context, listener McuListener, i subscriber, err = conn.newRemoteSubscriber(subctx, listener, info.id, publisher, streamType, info.conn, info.token) } if err != nil { - log.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, conn, err) + m.logger.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, conn, err) continue } @@ -2186,7 +2200,7 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ conn: conn, } } else { - log.Printf("No %s publisher %s found yet, deferring", streamType, publisher) + m.logger.Printf("No %s publisher %s found yet, deferring", streamType, publisher) ch := make(chan *proxyPublisherInfo, 1) getctx, cancel := context.WithCancel(ctx) defer cancel() @@ -2227,7 +2241,7 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ if errors.Is(err, context.Canceled) { return } else if err != nil { - log.Printf("Error getting %s publisher id %s from %s: %s", streamType, publisher, client.Target(), err) + m.logger.Printf("Error getting %s publisher id %s from %s: %s", streamType, publisher, client.Target(), err) return } else if id == "" { // Publisher not found on other server @@ -2235,7 +2249,7 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ } cancel() // Cancel pending RPC calls. - log.Printf("Found publisher id %s through %s on proxy %s", id, client.Target(), url) + m.logger.Printf("Found publisher id %s through %s on proxy %s", id, client.Target(), url) m.connectionsMu.RLock() connections := m.connections @@ -2254,14 +2268,14 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ if publisherConn == nil { publisherConn, err = newMcuProxyConnection(m, url, ip, connectToken) if err != nil { - log.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err) + m.logger.Printf("Could not create temporary connection to %s for %s publisher %s: %s", url, streamType, publisher, err) return } publisherConn.setTemporary() publisherConn.start() if err := publisherConn.waitUntilConnected(ctx); err != nil { - log.Printf("Could not establish new connection to %s: %s", publisherConn, err) + m.logger.Printf("Could not establish new connection to %s: %s", publisherConn, err) publisherConn.closeIfEmpty() return } @@ -2366,7 +2380,7 @@ func (m *mcuProxy) NewSubscriber(ctx context.Context, listener McuListener, publ if publisherInfo.conn.IsTemporary() { publisherInfo.conn.closeIfEmpty() } - log.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, publisherInfo.conn, err) + m.logger.Printf("Could not create subscriber for %s publisher %s on %s: %s", streamType, publisher, publisherInfo.conn, err) return nil, err } diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index 268eaa7..258f92a 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go @@ -883,19 +883,21 @@ func newMcuProxyForTestWithOptions(t *testing.T, options proxyTestOptions, idx i etcdConfig.AddOption("etcd", "endpoints", options.etcd.Config().ListenClientUrls[0].String()) etcdConfig.AddOption("etcd", "loglevel", "error") - etcdClient, err := NewEtcdClient(etcdConfig, "") + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + etcdClient, err := NewEtcdClient(logger, etcdConfig, "") require.NoError(err) t.Cleanup(func() { assert.NoError(t, etcdClient.Close()) }) - mcu, err := NewMcuProxy(cfg, etcdClient, grpcClients, dnsMonitor) + mcu, err := NewMcuProxy(ctx, cfg, etcdClient, grpcClients, dnsMonitor) require.NoError(err) t.Cleanup(func() { mcu.Stop() }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() require.NoError(mcu.Start(ctx)) @@ -942,11 +944,10 @@ func newMcuProxyForTest(t *testing.T, idx int) *mcuProxy { } func Test_ProxyAddRemoveConnections(t *testing.T) { - CatchLogForTest(t) t.Parallel() assert := assert.New(t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() server1 := NewProxyServerForTest(t, "DE") @@ -1017,7 +1018,6 @@ func Test_ProxyAddRemoveConnections(t *testing.T) { } func Test_ProxyAddRemoveConnectionsDnsDiscovery(t *testing.T) { - CatchLogForTest(t) assert := assert.New(t) require := require.New(t) @@ -1039,7 +1039,7 @@ func Test_ProxyAddRemoveConnectionsDnsDiscovery(t *testing.T) { ip1, }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() mcu, _ := newMcuProxyForTestWithOptions(t, proxyTestOptions{ @@ -1134,11 +1134,10 @@ func Test_ProxyAddRemoveConnectionsDnsDiscovery(t *testing.T) { } func Test_ProxyPublisherSubscriber(t *testing.T) { - CatchLogForTest(t) t.Parallel() mcu := newMcuProxyForTest(t, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -1170,11 +1169,10 @@ func Test_ProxyPublisherSubscriber(t *testing.T) { } func Test_ProxyPublisherCodecs(t *testing.T) { - CatchLogForTest(t) t.Parallel() mcu := newMcuProxyForTest(t, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -1197,11 +1195,10 @@ func Test_ProxyPublisherCodecs(t *testing.T) { } func Test_ProxyWaitForPublisher(t *testing.T) { - CatchLogForTest(t) t.Parallel() mcu := newMcuProxyForTest(t, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -1247,7 +1244,6 @@ func Test_ProxyWaitForPublisher(t *testing.T) { } func Test_ProxyPublisherBandwidth(t *testing.T) { - CatchLogForTest(t) t.Parallel() server1 := NewProxyServerForTest(t, "DE") server2 := NewProxyServerForTest(t, "DE") @@ -1256,7 +1252,7 @@ func Test_ProxyPublisherBandwidth(t *testing.T) { server2, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pub1Id := PublicSessionId("the-publisher-1") @@ -1317,7 +1313,6 @@ func Test_ProxyPublisherBandwidth(t *testing.T) { } func Test_ProxyPublisherBandwidthOverload(t *testing.T) { - CatchLogForTest(t) t.Parallel() server1 := NewProxyServerForTest(t, "DE") server2 := NewProxyServerForTest(t, "DE") @@ -1326,7 +1321,7 @@ func Test_ProxyPublisherBandwidthOverload(t *testing.T) { server2, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pub1Id := PublicSessionId("the-publisher-1") @@ -1390,7 +1385,6 @@ func Test_ProxyPublisherBandwidthOverload(t *testing.T) { } func Test_ProxyPublisherLoad(t *testing.T) { - CatchLogForTest(t) t.Parallel() server1 := NewProxyServerForTest(t, "DE") server2 := NewProxyServerForTest(t, "DE") @@ -1399,7 +1393,7 @@ func Test_ProxyPublisherLoad(t *testing.T) { server2, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pub1Id := PublicSessionId("the-publisher-1") @@ -1440,7 +1434,6 @@ func Test_ProxyPublisherLoad(t *testing.T) { } func Test_ProxyPublisherCountry(t *testing.T) { - CatchLogForTest(t) t.Parallel() serverDE := NewProxyServerForTest(t, "DE") serverUS := NewProxyServerForTest(t, "US") @@ -1449,7 +1442,7 @@ func Test_ProxyPublisherCountry(t *testing.T) { serverUS, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubDEId := PublicSessionId("the-publisher-de") @@ -1488,7 +1481,6 @@ func Test_ProxyPublisherCountry(t *testing.T) { } func Test_ProxyPublisherContinent(t *testing.T) { - CatchLogForTest(t) t.Parallel() serverDE := NewProxyServerForTest(t, "DE") serverUS := NewProxyServerForTest(t, "US") @@ -1497,7 +1489,7 @@ func Test_ProxyPublisherContinent(t *testing.T) { serverUS, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubDEId := PublicSessionId("the-publisher-de") @@ -1536,7 +1528,6 @@ func Test_ProxyPublisherContinent(t *testing.T) { } func Test_ProxySubscriberCountry(t *testing.T) { - CatchLogForTest(t) t.Parallel() serverDE := NewProxyServerForTest(t, "DE") serverUS := NewProxyServerForTest(t, "US") @@ -1545,7 +1536,7 @@ func Test_ProxySubscriberCountry(t *testing.T) { serverUS, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -1580,7 +1571,6 @@ func Test_ProxySubscriberCountry(t *testing.T) { } func Test_ProxySubscriberContinent(t *testing.T) { - CatchLogForTest(t) t.Parallel() serverDE := NewProxyServerForTest(t, "DE") serverUS := NewProxyServerForTest(t, "US") @@ -1589,7 +1579,7 @@ func Test_ProxySubscriberContinent(t *testing.T) { serverUS, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -1624,7 +1614,6 @@ func Test_ProxySubscriberContinent(t *testing.T) { } func Test_ProxySubscriberBandwidth(t *testing.T) { - CatchLogForTest(t) t.Parallel() serverDE := NewProxyServerForTest(t, "DE") serverUS := NewProxyServerForTest(t, "US") @@ -1633,7 +1622,7 @@ func Test_ProxySubscriberBandwidth(t *testing.T) { serverUS, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -1688,7 +1677,6 @@ func Test_ProxySubscriberBandwidth(t *testing.T) { } func Test_ProxySubscriberBandwidthOverload(t *testing.T) { - CatchLogForTest(t) t.Parallel() serverDE := NewProxyServerForTest(t, "DE") serverUS := NewProxyServerForTest(t, "US") @@ -1697,7 +1685,7 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) { serverUS, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -1806,7 +1794,6 @@ func (h *mockGrpcServerHub) CreateProxyToken(publisherId string) (string, error) } func Test_ProxyRemotePublisher(t *testing.T) { - CatchLogForTest(t) t.Parallel() etcd := NewEtcdForTest(t) @@ -1842,7 +1829,7 @@ func Test_ProxyRemotePublisher(t *testing.T) { }, 2) hub2.proxy.Store(mcu2) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -1886,7 +1873,6 @@ func Test_ProxyRemotePublisher(t *testing.T) { } func Test_ProxyMultipleRemotePublisher(t *testing.T) { - CatchLogForTest(t) t.Parallel() etcd := NewEtcdForTest(t) @@ -1938,7 +1924,7 @@ func Test_ProxyMultipleRemotePublisher(t *testing.T) { }, 3) hub3.proxy.Store(mcu3) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -1993,7 +1979,6 @@ func Test_ProxyMultipleRemotePublisher(t *testing.T) { } func Test_ProxyRemotePublisherWait(t *testing.T) { - CatchLogForTest(t) t.Parallel() etcd := NewEtcdForTest(t) @@ -2029,7 +2014,7 @@ func Test_ProxyRemotePublisherWait(t *testing.T) { }, 2) hub2.proxy.Store(mcu2) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -2089,7 +2074,6 @@ func Test_ProxyRemotePublisherWait(t *testing.T) { } func Test_ProxyRemotePublisherTemporary(t *testing.T) { - CatchLogForTest(t) t.Parallel() etcd := NewEtcdForTest(t) @@ -2123,7 +2107,7 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) { }, 2) hub2.proxy.Store(mcu2) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -2198,7 +2182,6 @@ loop: } func Test_ProxyConnectToken(t *testing.T) { - CatchLogForTest(t) t.Parallel() etcd := NewEtcdForTest(t) @@ -2235,7 +2218,7 @@ func Test_ProxyConnectToken(t *testing.T) { }, 2) hub2.proxy.Store(mcu2) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -2279,7 +2262,6 @@ func Test_ProxyConnectToken(t *testing.T) { } func Test_ProxyPublisherToken(t *testing.T) { - CatchLogForTest(t) t.Parallel() etcd := NewEtcdForTest(t) @@ -2321,7 +2303,7 @@ func Test_ProxyPublisherToken(t *testing.T) { server1.servers = append(server1.servers, server2) server2.servers = append(server2.servers, server1) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -2365,7 +2347,6 @@ func Test_ProxyPublisherToken(t *testing.T) { } func Test_ProxyPublisherTimeout(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -2374,7 +2355,7 @@ func Test_ProxyPublisherTimeout(t *testing.T) { servers: []*TestProxyServerHandler{server}, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -2406,7 +2387,6 @@ func Test_ProxyPublisherTimeout(t *testing.T) { } func Test_ProxySubscriberTimeout(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -2415,7 +2395,7 @@ func Test_ProxySubscriberTimeout(t *testing.T) { servers: []*TestProxyServerHandler{server}, }, 0) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() pubId := PublicSessionId("the-publisher") @@ -2466,7 +2446,6 @@ func Test_ProxyReconnectAfter(t *testing.T) { } for _, reason := range reasons { t.Run(reason, func(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -2479,7 +2458,7 @@ func Test_ProxyReconnectAfter(t *testing.T) { require.Len(connections, 1) sessionId := connections[0].SessionId() - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() client := server.GetSingleClient() @@ -2507,7 +2486,6 @@ func Test_ProxyReconnectAfter(t *testing.T) { } func Test_ProxyReconnectAfterShutdown(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -2520,7 +2498,7 @@ func Test_ProxyReconnectAfterShutdown(t *testing.T) { require.Len(connections, 1) sessionId := connections[0].SessionId() - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() client := server.GetSingleClient() @@ -2547,7 +2525,6 @@ func Test_ProxyReconnectAfterShutdown(t *testing.T) { } func Test_ProxyResume(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -2560,7 +2537,7 @@ func Test_ProxyResume(t *testing.T) { require.Len(connections, 1) sessionId := connections[0].SessionId() - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() client := server.GetSingleClient() @@ -2580,7 +2557,6 @@ func Test_ProxyResume(t *testing.T) { } func Test_ProxyResumeFail(t *testing.T) { - CatchLogForTest(t) t.Parallel() require := require.New(t) assert := assert.New(t) @@ -2593,7 +2569,7 @@ func Test_ProxyResumeFail(t *testing.T) { require.Len(connections, 1) sessionId := connections[0].SessionId() - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() client := server.GetSingleClient() diff --git a/mcu_test.go b/mcu_test.go index 4023a90..378a2ed 100644 --- a/mcu_test.go +++ b/mcu_test.go @@ -25,10 +25,10 @@ import ( "context" "errors" "fmt" - "log" "maps" "sync" "sync/atomic" + "testing" "github.com/dlintw/goconf" @@ -41,6 +41,7 @@ var ( ) type TestMCU struct { + t *testing.T mu sync.Mutex // +checklocks:mu publishers map[PublicSessionId]*TestMCUPublisher @@ -51,11 +52,13 @@ type TestMCU struct { maxScreenBitrate api.AtomicBandwidth } -func NewTestMCU() (*TestMCU, error) { +func NewTestMCU(t *testing.T) *TestMCU { return &TestMCU{ + t: t, + publishers: make(map[PublicSessionId]*TestMCUPublisher), subscribers: make(map[string]*TestMCUSubscriber), - }, nil + } } func (m *TestMCU) GetBandwidthLimits() (api.Bandwidth, api.Bandwidth) { @@ -105,6 +108,7 @@ func (m *TestMCU) NewPublisher(ctx context.Context, listener McuListener, id Pub } pub := &TestMCUPublisher{ TestMCUClient: TestMCUClient{ + t: m.t, id: string(id), sid: sid, streamType: streamType, @@ -147,6 +151,7 @@ func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publi id := newRandomString(8) sub := &TestMCUSubscriber{ TestMCUClient: TestMCUClient{ + t: m.t, id: id, streamType: streamType, }, @@ -157,6 +162,7 @@ func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publi } type TestMCUClient struct { + t *testing.T closed atomic.Bool id string @@ -182,7 +188,8 @@ func (c *TestMCUClient) MaxBitrate() api.Bandwidth { func (c *TestMCUClient) Close(ctx context.Context) { if c.closed.CompareAndSwap(false, true) { - log.Printf("Close MCU client %s", c.id) + logger := NewLoggerForTest(c.t) + logger.Printf("Close MCU client %s", c.id) } } diff --git a/natsclient.go b/natsclient.go index ea0a364..1474155 100644 --- a/natsclient.go +++ b/natsclient.go @@ -26,7 +26,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "log" "net/url" "os" "os/signal" @@ -64,17 +63,19 @@ func GetEncodedSubject(prefix string, suffix string) string { } type natsClient struct { - conn *nats.Conn + logger Logger + conn *nats.Conn } -func NewNatsClient(url string, options ...nats.Option) (NatsClient, error) { +func NewNatsClient(ctx context.Context, url string, options ...nats.Option) (NatsClient, error) { + logger := LoggerFromContext(ctx) if url == ":loopback:" { - log.Printf("WARNING: events url %s is deprecated, please use %s instead", url, NatsLoopbackUrl) + logger.Printf("WARNING: events url %s is deprecated, please use %s instead", url, NatsLoopbackUrl) url = NatsLoopbackUrl } if url == NatsLoopbackUrl { - log.Println("Using internal NATS loopback client") - return NewLoopbackNatsClient() + logger.Println("Using internal NATS loopback client") + return NewLoopbackNatsClient(logger) } backoff, err := NewExponentialBackoff(initialConnectInterval, maxConnectInterval) @@ -82,7 +83,9 @@ func NewNatsClient(url string, options ...nats.Option) (NatsClient, error) { return nil, err } - client := &natsClient{} + client := &natsClient{ + logger: logger, + } options = append([]nats.Option{ nats.ClosedHandler(client.onClosed), @@ -92,12 +95,12 @@ func NewNatsClient(url string, options ...nats.Option) (NatsClient, error) { }, options...) client.conn, err = nats.Connect(url, options...) - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + ctx, stop := signal.NotifyContext(ctx, os.Interrupt) defer stop() // The initial connect must succeed, so we retry in the case of an error. for err != nil { - log.Printf("Could not create connection (%s), will retry in %s", err, backoff.NextWait()) + logger.Printf("Could not create connection (%s), will retry in %s", err, backoff.NextWait()) backoff.Wait(ctx) if ctx.Err() != nil { return nil, fmt.Errorf("interrupted") @@ -105,7 +108,7 @@ func NewNatsClient(url string, options ...nats.Option) (NatsClient, error) { client.conn, err = nats.Connect(url) } - log.Printf("Connection established to %s (%s)", removeURLCredentials(client.conn.ConnectedUrl()), client.conn.ConnectedServerId()) + logger.Printf("Connection established to %s (%s)", removeURLCredentials(client.conn.ConnectedUrl()), client.conn.ConnectedServerId()) return client, nil } @@ -114,15 +117,15 @@ func (c *natsClient) Close() { } func (c *natsClient) onClosed(conn *nats.Conn) { - log.Println("NATS client closed", conn.LastError()) + c.logger.Println("NATS client closed", conn.LastError()) } func (c *natsClient) onDisconnected(conn *nats.Conn) { - log.Println("NATS client disconnected") + c.logger.Println("NATS client disconnected") } func (c *natsClient) onReconnected(conn *nats.Conn) { - log.Printf("NATS client reconnected to %s (%s)", conn.ConnectedUrl(), conn.ConnectedServerId()) + c.logger.Printf("NATS client reconnected to %s (%s)", conn.ConnectedUrl(), conn.ConnectedServerId()) } func (c *natsClient) Subscribe(subject string, ch chan *nats.Msg) (NatsSubscription, error) { diff --git a/natsclient_loopback.go b/natsclient_loopback.go index 6b8f56a..1421d9a 100644 --- a/natsclient_loopback.go +++ b/natsclient_loopback.go @@ -24,7 +24,6 @@ package signaling import ( "container/list" "encoding/json" - "log" "strings" "sync" @@ -32,6 +31,8 @@ import ( ) type LoopbackNatsClient struct { + logger Logger + mu sync.Mutex // +checklocks:mu subscriptions map[string]map[*loopbackNatsSubscription]bool @@ -42,8 +43,10 @@ type LoopbackNatsClient struct { incoming list.List } -func NewLoopbackNatsClient() (NatsClient, error) { +func NewLoopbackNatsClient(logger Logger) (NatsClient, error) { client := &LoopbackNatsClient{ + logger: logger, + subscriptions: make(map[string]map[*loopbackNatsSubscription]bool), } client.wakeup.L = &client.mu @@ -85,7 +88,7 @@ func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) { select { case ch <- msg: default: - log.Printf("Slow consumer %s, dropping message", msg.Subject) + c.logger.Printf("Slow consumer %s, dropping message", msg.Subject) } } } diff --git a/natsclient_loopback_test.go b/natsclient_loopback_test.go index d6cf5de..6cf6ed7 100644 --- a/natsclient_loopback_test.go +++ b/natsclient_loopback_test.go @@ -52,7 +52,8 @@ func (c *LoopbackNatsClient) waitForSubscriptionsEmpty(ctx context.Context, t *t } func CreateLoopbackNatsClientForTest(t *testing.T) NatsClient { - result, err := NewLoopbackNatsClient() + logger := NewLoggerForTest(t) + result, err := NewLoopbackNatsClient(logger) require.NoError(t, err) t.Cleanup(func() { result.Close() diff --git a/natsclient_test.go b/natsclient_test.go index 473a543..cc3b22b 100644 --- a/natsclient_test.go +++ b/natsclient_test.go @@ -55,7 +55,9 @@ func startLocalNatsServerPort(t *testing.T, port int) (*server.Server, int) { func CreateLocalNatsClientForTest(t *testing.T, options ...nats.Option) (*server.Server, int, NatsClient) { t.Helper() server, port := startLocalNatsServer(t) - result, err := NewNatsClient(server.ClientURL(), options...) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) + result, err := NewNatsClient(ctx, server.ClientURL(), options...) require.NoError(t, err) t.Cleanup(func() { result.Close() @@ -106,7 +108,6 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { } func TestNatsClient_Subscribe(t *testing.T) { - CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { _, _, client := CreateLocalNatsClientForTest(t) @@ -121,7 +122,6 @@ func testNatsClient_PublishAfterClose(t *testing.T, client NatsClient) { } func TestNatsClient_PublishAfterClose(t *testing.T) { - CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { _, _, client := CreateLocalNatsClientForTest(t) @@ -138,7 +138,6 @@ func testNatsClient_SubscribeAfterClose(t *testing.T, client NatsClient) { } func TestNatsClient_SubscribeAfterClose(t *testing.T) { - CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { _, _, client := CreateLocalNatsClientForTest(t) @@ -161,7 +160,6 @@ func testNatsClient_BadSubjects(t *testing.T, client NatsClient) { } func TestNatsClient_BadSubjects(t *testing.T) { - CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { _, _, client := CreateLocalNatsClientForTest(t) @@ -170,7 +168,6 @@ func TestNatsClient_BadSubjects(t *testing.T) { } func TestNatsClient_MaxReconnects(t *testing.T) { - CatchLogForTest(t) ensureNoGoroutinesLeak(t, func(t *testing.T) { assert := assert.New(t) require := require.New(t) diff --git a/proxy/main.go b/proxy/main.go index f768eb5..acc024e 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -22,6 +22,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -64,28 +65,33 @@ func main() { } sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt) signal.Notify(sigChan, syscall.SIGHUP) signal.Notify(sigChan, syscall.SIGUSR1) - log.Printf("Starting up version %s/%s as pid %d", version, runtime.Version(), os.Getpid()) + stopCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + defer stop() + + logger := log.Default() + stopCtx = signaling.NewLoggerContext(stopCtx, logger) + + logger.Printf("Starting up version %s/%s as pid %d", version, runtime.Version(), os.Getpid()) config, err := goconf.ReadConfigFile(*configFlag) if err != nil { - log.Fatal("Could not read configuration: ", err) + logger.Fatal("Could not read configuration: ", err) } - log.Printf("Using a maximum of %d CPUs", runtime.GOMAXPROCS(0)) + logger.Printf("Using a maximum of %d CPUs", runtime.GOMAXPROCS(0)) r := mux.NewRouter() - proxy, err := NewProxyServer(r, version, config) + proxy, err := NewProxyServer(stopCtx, r, version, config) if err != nil { - log.Fatal(err) + logger.Fatal(err) } if err := proxy.Start(config); err != nil { - log.Fatal(err) + logger.Fatal(err) } defer proxy.Stop() @@ -101,10 +107,10 @@ func main() { for address := range signaling.SplitEntries(addr, " ") { go func(address string) { - log.Println("Listening on", address) + logger.Println("Listening on", address) listener, err := net.Listen("tcp", address) if err != nil { - log.Fatal("Could not start listening: ", err) + logger.Fatal("Could not start listening: ", err) } srv := &http.Server{ Handler: r, @@ -114,7 +120,7 @@ func main() { WriteTimeout: time.Duration(writeTimeout) * time.Second, } if err := srv.Serve(listener); err != nil { - log.Fatal("Could not start server: ", err) + logger.Fatal("Could not start server: ", err) } }(address) } @@ -123,24 +129,24 @@ func main() { loop: for { select { + case <-stopCtx.Done(): + logger.Println("Interrupted") + break loop case sig := <-sigChan: switch sig { - case os.Interrupt: - log.Println("Interrupted") - break loop case syscall.SIGHUP: - log.Printf("Received SIGHUP, reloading %s", *configFlag) + logger.Printf("Received SIGHUP, reloading %s", *configFlag) if config, err := goconf.ReadConfigFile(*configFlag); err != nil { - log.Printf("Could not read configuration from %s: %s", *configFlag, err) + logger.Printf("Could not read configuration from %s: %s", *configFlag, err) } else { proxy.Reload(config) } case syscall.SIGUSR1: - log.Printf("Received SIGUSR1, scheduling server to shutdown") + logger.Printf("Received SIGUSR1, scheduling server to shutdown") proxy.ScheduleShutdown() } case <-proxy.ShutdownChannel(): - log.Printf("All clients disconnected, shutting down") + logger.Printf("All clients disconnected, shutting down") break loop } } diff --git a/proxy/proxy_remote.go b/proxy/proxy_remote.go index 6eaad05..9a73122 100644 --- a/proxy/proxy_remote.go +++ b/proxy/proxy_remote.go @@ -27,7 +27,6 @@ import ( "crypto/tls" "encoding/json" "errors" - "log" "math/rand/v2" "net" "net/http" @@ -62,9 +61,10 @@ var ( ) type RemoteConnection struct { - mu sync.Mutex - p *ProxyServer - url *url.URL + logger signaling.Logger + mu sync.Mutex + p *ProxyServer + url *url.URL // +checklocks:mu conn *websocket.Conn closeCtx context.Context @@ -102,6 +102,7 @@ func NewRemoteConnection(p *ProxyServer, proxyUrl string, tokenId string, tokenK closeCtx, closeFunc := context.WithCancel(context.Background()) result := &RemoteConnection{ + logger: p.logger, p: p, url: u, closeCtx: closeCtx, @@ -135,7 +136,7 @@ func (c *RemoteConnection) SessionId() signaling.PublicSessionId { func (c *RemoteConnection) reconnect() { u, err := c.url.Parse("proxy") if err != nil { - log.Printf("Could not resolve url to proxy at %s: %s", c, err) + c.logger.Printf("Could not resolve url to proxy at %s: %s", c, err) c.scheduleReconnect() return } @@ -153,12 +154,12 @@ func (c *RemoteConnection) reconnect() { conn, _, err := dialer.DialContext(context.TODO(), u.String(), nil) if err != nil { - log.Printf("Error connecting to proxy at %s: %s", c, err) + c.logger.Printf("Error connecting to proxy at %s: %s", c, err) c.scheduleReconnect() return } - log.Printf("Connected to %s", c) + c.logger.Printf("Connected to %s", c) c.mu.Lock() c.connectedSince = time.Now() @@ -180,7 +181,7 @@ func (c *RemoteConnection) sendReconnectHello() bool { defer c.mu.Unlock() if err := c.sendHello(c.closeCtx); err != nil { - log.Printf("Error sending hello request to proxy at %s: %s", c, err) + c.logger.Printf("Error sending hello request to proxy at %s: %s", c, err) return false } @@ -197,7 +198,7 @@ func (c *RemoteConnection) scheduleReconnect() { // +checklocks:c.mu func (c *RemoteConnection) scheduleReconnectLocked() { if err := c.sendCloseLocked(); err != nil && err != ErrNotConnected { - log.Printf("Could not send close message to %s: %s", c, err) + c.logger.Printf("Could not send close message to %s: %s", c, err) } c.closeLocked() @@ -366,20 +367,20 @@ func (c *RemoteConnection) readPump(conn *websocket.Conn) { websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { if !errors.Is(err, net.ErrClosed) || c.closeCtx.Err() == nil { - log.Printf("Error reading from %s: %v", c, err) + c.logger.Printf("Error reading from %s: %v", c, err) } } break } if msgType != websocket.TextMessage { - log.Printf("unexpected message type %q (%s)", msgType, string(msg)) + c.logger.Printf("unexpected message type %q (%s)", msgType, string(msg)) continue } var message signaling.ProxyServerMessage if err := json.Unmarshal(msg, &message); err != nil { - log.Printf("could not decode message %s: %s", string(msg), err) + c.logger.Printf("could not decode message %s: %s", string(msg), err) continue } @@ -406,7 +407,7 @@ func (c *RemoteConnection) sendPing() bool { msg := strconv.FormatInt(now.UnixNano(), 10) c.conn.SetWriteDeadline(now.Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil { - log.Printf("Could not send ping to proxy at %s: %v", c, err) + c.logger.Printf("Could not send ping to proxy at %s: %v", c, err) go c.scheduleReconnect() return false } @@ -441,16 +442,16 @@ func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { switch msg.Type { case "error": if msg.Error.Code == "no_such_session" { - log.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c) + c.logger.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c) c.sessionId = "" if err := c.sendHello(c.closeCtx); err != nil { - log.Printf("Could not send hello request to %s: %s", c, err) + c.logger.Printf("Could not send hello request to %s: %s", c, err) c.scheduleReconnectLocked() } return } - log.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error) + c.logger.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error) c.scheduleReconnectLocked() case "hello": resumed := c.sessionId == msg.Hello.SessionId @@ -459,16 +460,16 @@ func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { country := "" if msg.Hello.Server != nil { if country = msg.Hello.Server.Country; country != "" && !signaling.IsValidCountry(country) { - log.Printf("Proxy %s sent invalid country %s in hello response", c, country) + c.logger.Printf("Proxy %s sent invalid country %s in hello response", c, country) country = "" } } if resumed { - log.Printf("Resumed session %s on %s", c.sessionId, c) + c.logger.Printf("Resumed session %s on %s", c.sessionId, c) } else if country != "" { - log.Printf("Received session %s from %s (in %s)", c.sessionId, c, country) + c.logger.Printf("Received session %s from %s (in %s)", c.sessionId, c, country) } else { - log.Printf("Received session %s from %s", c.sessionId, c) + c.logger.Printf("Received session %s from %s", c.sessionId, c) } pending := c.pendingMessages @@ -479,11 +480,11 @@ func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { } if err := c.sendMessageLocked(c.closeCtx, m); err != nil { - log.Printf("Could not send pending message %+v to %s: %s", m, c, err) + c.logger.Printf("Could not send pending message %+v to %s: %s", m, c, err) } } default: - log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c) + c.logger.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c) c.scheduleReconnectLocked() } } @@ -516,7 +517,7 @@ func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) { case "event": c.processEvent(msg) case "bye": - log.Printf("Connection to %s was closed: %s", c, msg.Bye.Reason) + c.logger.Printf("Connection to %s was closed: %s", c, msg.Bye.Reason) if msg.Bye.Reason == "session_expired" { // Don't try to resume expired session. c.mu.Lock() @@ -525,7 +526,7 @@ func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) { } c.scheduleReconnect() default: - log.Printf("Received unsupported message %+v from %s", msg, c) + c.logger.Printf("Received unsupported message %+v from %s", msg, c) } } @@ -534,10 +535,10 @@ func (c *RemoteConnection) processEvent(msg *signaling.ProxyServerMessage) { case "update-load": // Ignore case "publisher-closed": - log.Printf("Remote publisher %s was closed on %s", msg.Event.ClientId, c) + c.logger.Printf("Remote publisher %s was closed on %s", msg.Event.ClientId, c) c.p.RemotePublisherDeleted(signaling.PublicSessionId(msg.Event.ClientId)) default: - log.Printf("Received unsupported event %+v from %s", msg, c) + c.logger.Printf("Received unsupported event %+v from %s", msg, c) } } diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index ad932db..72ed239 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -30,7 +30,6 @@ import ( "errors" "fmt" "io" - "log" "net" "net/http" "net/http/pprof" @@ -111,6 +110,7 @@ type ProxyServer struct { welcomeMsg *signaling.WelcomeServerMessage config *goconf.ConfigFile mcuTimeout time.Duration + logger signaling.Logger url string mcu signaling.Mcu @@ -189,16 +189,16 @@ func GetLocalIP() (string, error) { return "", nil } -func getTargetBandwidths(config *goconf.ConfigFile) (api.Bandwidth, api.Bandwidth) { +func getTargetBandwidths(logger signaling.Logger, config *goconf.ConfigFile) (api.Bandwidth, api.Bandwidth) { maxIncomingValue, _ := config.GetInt("bandwidth", "incoming") if maxIncomingValue < 0 { maxIncomingValue = 0 } maxIncoming := api.BandwidthFromMegabits(uint64(maxIncomingValue)) if maxIncoming > 0 { - log.Printf("Target bandwidth for incoming streams: %s", maxIncoming) + logger.Printf("Target bandwidth for incoming streams: %s", maxIncoming) } else { - log.Printf("Target bandwidth for incoming streams: unlimited") + logger.Printf("Target bandwidth for incoming streams: unlimited") } maxOutgoingValue, _ := config.GetInt("bandwidth", "outgoing") @@ -207,15 +207,16 @@ func getTargetBandwidths(config *goconf.ConfigFile) (api.Bandwidth, api.Bandwidt } maxOutgoing := api.BandwidthFromMegabits(uint64(maxOutgoingValue)) if maxOutgoing > 0 { - log.Printf("Target bandwidth for outgoing streams: %s", maxOutgoing) + logger.Printf("Target bandwidth for outgoing streams: %s", maxOutgoing) } else { - log.Printf("Target bandwidth for outgoing streams: unlimited") + logger.Printf("Target bandwidth for outgoing streams: unlimited") } return maxIncoming, maxOutgoing } -func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*ProxyServer, error) { +func NewProxyServer(ctx context.Context, r *mux.Router, version string, config *goconf.ConfigFile) (*ProxyServer, error) { + logger := signaling.LoggerFromContext(ctx) hashKey := make([]byte, 64) if _, err := rand.Read(hashKey); err != nil { return nil, fmt.Errorf("could not generate random hash key: %s", err) @@ -235,9 +236,9 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* switch tokenType { case TokenTypeEtcd: - tokens, err = NewProxyTokensEtcd(config) + tokens, err = NewProxyTokensEtcd(logger, config) case TokenTypeStatic: - tokens, err = NewProxyTokensStatic(config) + tokens, err = NewProxyTokensStatic(logger, config) default: return nil, fmt.Errorf("unsupported token type configured: %s", tokenType) } @@ -252,10 +253,10 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* } if !statsAllowedIps.Empty() { - log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) + logger.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) } else { statsAllowedIps = signaling.DefaultAllowedIps() - log.Printf("No IPs configured for the stats endpoint, only allowing access from %s", statsAllowedIps) + logger.Printf("No IPs configured for the stats endpoint, only allowing access from %s", statsAllowedIps) } trustedProxies, _ := config.GetString("app", "trustedproxies") @@ -265,20 +266,20 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* } if !trustedProxiesIps.Empty() { - log.Printf("Trusted proxies: %s", trustedProxiesIps) + logger.Printf("Trusted proxies: %s", trustedProxiesIps) } else { trustedProxiesIps = signaling.DefaultTrustedProxies - log.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) + logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) } country, _ := config.GetString("app", "country") country = strings.ToUpper(country) if signaling.IsValidCountry(country) { - log.Printf("Sending %s as country information", country) + logger.Printf("Sending %s as country information", country) } else if country != "" { return nil, fmt.Errorf("invalid country: %s", country) } else { - log.Printf("Not sending country information") + logger.Printf("Not sending country information") } welcome := map[string]string{ @@ -308,7 +309,7 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* if err != nil { return nil, fmt.Errorf("could not parse private key from %s: %s", tokenKeyFilename, err) } - log.Printf("Using \"%s\" as token id for remote streams", tokenId) + logger.Printf("Using \"%s\" as token id for remote streams", tokenId) remoteHostname, _ = config.GetString("app", "hostname") if remoteHostname == "" { @@ -318,23 +319,23 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* } } if remoteHostname == "" { - log.Printf("WARNING: Could not determine hostname for remote streams, will be disabled. Please configure manually.") + logger.Printf("WARNING: Could not determine hostname for remote streams, will be disabled. Please configure manually.") } else { - log.Printf("Using \"%s\" as hostname for remote streams", remoteHostname) + logger.Printf("Using \"%s\" as hostname for remote streams", remoteHostname) } skipverify, _ := config.GetBool("backend", "skipverify") if skipverify { - log.Println("WARNING: Remote stream requests verification is disabled!") + logger.Println("WARNING: Remote stream requests verification is disabled!") remoteTlsConfig = &tls.Config{ InsecureSkipVerify: skipverify, } } } else { - log.Printf("No token id configured, remote streams will be disabled") + logger.Printf("No token id configured, remote streams will be disabled") } - maxIncoming, maxOutgoing := getTargetBandwidths(config) + maxIncoming, maxOutgoing := getTargetBandwidths(logger, config) mcuTimeoutSeconds, _ := config.GetInt("mcu", "timeout") if mcuTimeoutSeconds <= 0 { @@ -353,6 +354,7 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* }, config: config, mcuTimeout: mcuTimeout, + logger: logger, shutdownChannel: make(chan struct{}), @@ -389,7 +391,7 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (* statsLoadCurrent.Set(0) if debug, _ := config.GetBool("app", "debug"); debug { - log.Println("Installing debug handlers in \"/debug/pprof\"") + logger.Println("Installing debug handlers in \"/debug/pprof\"") s := r.PathPrefix("/debug/pprof").Subrouter() s.HandleFunc("", result.setCommonHeaders(result.validateStatsRequest(func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/debug/pprof/", http.StatusTemporaryRedirect) @@ -455,14 +457,14 @@ func (s *ProxyServer) Start(config *goconf.ConfigFile) error { mcu.SetOnDisconnected(s.onMcuDisconnected) err = mcu.Start(ctx) if err != nil { - log.Printf("Could not create %s MCU at %s: %s", mcuType, s.url, err) + s.logger.Printf("Could not create %s MCU at %s: %s", mcuType, s.url, err) } } if err == nil { break } - log.Printf("Could not initialize %s MCU at %s (%s) will retry in %s", mcuType, s.url, err, backoff.NextWait()) + s.logger.Printf("Could not initialize %s MCU at %s (%s) will retry in %s", mcuType, s.url, err, backoff.NextWait()) backoff.Wait(ctx) if ctx.Err() != nil { return fmt.Errorf("cancelled") @@ -572,7 +574,7 @@ func (s *ProxyServer) expireSessions() { continue } - log.Printf("Delete expired session %s", session.PublicId()) + s.logger.Printf("Delete expired session %s", session.PublicId()) s.deleteSessionLocked(session.Sid()) } } @@ -616,30 +618,30 @@ func (s *ProxyServer) Reload(config *goconf.ConfigFile) { statsAllowed, _ := config.GetString("stats", "allowed_ips") if statsAllowedIps, err := signaling.ParseAllowedIps(statsAllowed); err == nil { if !statsAllowedIps.Empty() { - log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) + s.logger.Printf("Only allowing access to the stats endpoint from %s", statsAllowed) } else { statsAllowedIps = signaling.DefaultAllowedIps() - log.Printf("No IPs configured for the stats endpoint, only allowing access from %s", statsAllowedIps) + s.logger.Printf("No IPs configured for the stats endpoint, only allowing access from %s", statsAllowedIps) } s.statsAllowedIps.Store(statsAllowedIps) } else { - log.Printf("Error parsing allowed stats ips from \"%s\": %s", statsAllowedIps, err) + s.logger.Printf("Error parsing allowed stats ips from \"%s\": %s", statsAllowedIps, err) } trustedProxies, _ := config.GetString("app", "trustedproxies") if trustedProxiesIps, err := signaling.ParseAllowedIps(trustedProxies); err == nil { if !trustedProxiesIps.Empty() { - log.Printf("Trusted proxies: %s", trustedProxiesIps) + s.logger.Printf("Trusted proxies: %s", trustedProxiesIps) } else { trustedProxiesIps = signaling.DefaultTrustedProxies - log.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) + s.logger.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps) } s.trustedProxies.Store(trustedProxiesIps) } else { - log.Printf("Error parsing trusted proxies from \"%s\": %s", trustedProxies, err) + s.logger.Printf("Error parsing trusted proxies from \"%s\": %s", trustedProxies, err) } - maxIncoming, maxOutgoing := getTargetBandwidths(config) + maxIncoming, maxOutgoing := getTargetBandwidths(s.logger, config) oldIncoming := s.maxIncoming.Swap(maxIncoming) oldOutgoing := s.maxOutgoing.Swap(maxOutgoing) if oldIncoming != maxIncoming || oldOutgoing != maxOutgoing { @@ -672,19 +674,20 @@ func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { header.Set("X-Spreed-Signaling-Features", strings.Join(s.welcomeMsg.Features, ", ")) conn, err := s.upgrader.Upgrade(w, r, header) if err != nil { - log.Printf("Could not upgrade request from %s: %s", addr, err) + s.logger.Printf("Could not upgrade request from %s: %s", addr, err) return } + ctx := signaling.NewLoggerContext(r.Context(), s.logger) if conn.Subprotocol() == signaling.JanusEventsSubprotocol { agent := r.Header.Get("User-Agent") - signaling.RunJanusEventsHandler(r.Context(), s.mcu, conn, addr, agent) + signaling.RunJanusEventsHandler(ctx, s.mcu, conn, addr, agent) return } - client, err := NewProxyClient(r.Context(), s, conn, addr) + client, err := NewProxyClient(ctx, s, conn, addr) if err != nil { - log.Printf("Could not create client for %s: %s", addr, err) + s.logger.Printf("Could not create client for %s: %s", addr, err) return } @@ -693,11 +696,11 @@ func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { } func (s *ProxyServer) clientClosed(client *signaling.Client) { - log.Printf("Connection from %s closed", client.RemoteAddr()) + s.logger.Printf("Connection from %s closed", client.RemoteAddr()) } func (s *ProxyServer) onMcuConnected() { - log.Printf("Connection to %s established", s.url) + s.logger.Printf("Connection to %s established", s.url) msg := &signaling.ProxyServerMessage{ Type: "event", Event: &signaling.EventProxyServerMessage{ @@ -716,7 +719,7 @@ func (s *ProxyServer) onMcuDisconnected() { return } - log.Printf("Connection to %s lost", s.url) + s.logger.Printf("Connection to %s lost", s.url) msg := &signaling.ProxyServerMessage{ Type: "event", Event: &signaling.EventProxyServerMessage{ @@ -747,14 +750,14 @@ func (s *ProxyServer) sendShutdownScheduled(session *ProxySession) { func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { if proxyDebugMessages { - log.Printf("Message: %s", string(data)) + s.logger.Printf("Message: %s", string(data)) } var message signaling.ProxyClientMessage if err := message.UnmarshalJSON(data); err != nil { if session := client.GetSession(); session != nil { - log.Printf("Error decoding message from client %s: %v", session.PublicId(), err) + s.logger.Printf("Error decoding message from client %s: %v", session.PublicId(), err) } else { - log.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err) + s.logger.Printf("Error decoding message from %s: %v", client.RemoteAddr(), err) } client.SendError(signaling.InvalidFormat) return @@ -762,9 +765,9 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { if err := message.CheckValid(); err != nil { if session := client.GetSession(); session != nil { - log.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) + s.logger.Printf("Invalid message %+v from client %s: %v", message, session.PublicId(), err) } else { - log.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) + s.logger.Printf("Invalid message %+v from %s: %v", message, client.RemoteAddr(), err) } client.SendMessage(message.NewErrorServerMessage(signaling.InvalidFormat)) return @@ -788,7 +791,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { return } - log.Printf("Resumed session %s", session.PublicId()) + s.logger.Printf("Resumed session %s", session.PublicId()) session.MarkUsed() if s.shutdownScheduled.Load() { s.sendShutdownScheduled(session) @@ -945,7 +948,7 @@ func (s *ProxyServer) addRemotePublisher(publisher *proxyRemotePublisher) { } publishers[publisher] = true - log.Printf("Add remote publisher to %s", publisher.remoteUrl) + s.logger.Printf("Add remote publisher to %s", publisher.remoteUrl) } func (s *ProxyServer) hasRemotePublishers() bool { @@ -959,7 +962,7 @@ func (s *ProxyServer) removeRemotePublisher(publisher *proxyRemotePublisher) { s.remoteConnectionsLock.Lock() defer s.remoteConnectionsLock.Unlock() - log.Printf("Removing remote publisher to %s", publisher.remoteUrl) + s.logger.Printf("Removing remote publisher to %s", publisher.remoteUrl) publishers, found := s.remotePublishers[publisher.remoteUrl] if !found { return @@ -974,9 +977,9 @@ func (s *ProxyServer) removeRemotePublisher(publisher *proxyRemotePublisher) { if conn, found := s.remoteConnections[publisher.remoteUrl]; found { delete(s.remoteConnections, publisher.remoteUrl) if err := conn.Close(); err != nil { - log.Printf("Error closing remote connection to %s: %s", publisher.remoteUrl, err) + s.logger.Printf("Error closing remote connection to %s: %s", publisher.remoteUrl, err) } else { - log.Printf("Remote connection to %s closed", publisher.remoteUrl) + s.logger.Printf("Remote connection to %s closed", publisher.remoteUrl) } } } @@ -1006,16 +1009,16 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } publisher, err := s.mcu.NewPublisher(ctx2, session, signaling.PublicSessionId(id), cmd.Sid, cmd.StreamType, *settings, &emptyInitiator{}) if err == context.DeadlineExceeded { - log.Printf("Timeout while creating %s publisher %s for %s", cmd.StreamType, id, session.PublicId()) + s.logger.Printf("Timeout while creating %s publisher %s for %s", cmd.StreamType, id, session.PublicId()) session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingPublisher)) return } else if err != nil { - log.Printf("Error while creating %s publisher %s for %s: %s", cmd.StreamType, id, session.PublicId(), err) + s.logger.Printf("Error while creating %s publisher %s for %s: %s", cmd.StreamType, id, session.PublicId(), err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) return } - log.Printf("Created %s publisher %s as %s for %s", cmd.StreamType, publisher.Id(), id, session.PublicId()) + s.logger.Printf("Created %s publisher %s as %s for %s", cmd.StreamType, publisher.Id(), id, session.PublicId()) session.StorePublisher(ctx, id, publisher) s.StoreClient(id, publisher) @@ -1038,7 +1041,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s handleCreateError := func(err error) { if err == context.DeadlineExceeded { - log.Printf("Timeout while creating %s subscriber on %s for %s", cmd.StreamType, publisherId, session.PublicId()) + s.logger.Printf("Timeout while creating %s subscriber on %s for %s", cmd.StreamType, publisherId, session.PublicId()) session.sendMessage(message.NewErrorServerMessage(TimeoutCreatingSubscriber)) return } else if errors.Is(err, signaling.ErrRemoteStreamsNotSupported) { @@ -1046,7 +1049,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s return } - log.Printf("Error while creating %s subscriber on %s for %s: %s", cmd.StreamType, publisherId, session.PublicId(), err) + s.logger.Printf("Error while creating %s subscriber on %s for %s: %s", cmd.StreamType, publisherId, session.PublicId(), err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) } @@ -1080,7 +1083,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s subCtx, cancel := context.WithTimeout(ctx, remotePublisherTimeout) defer cancel() - log.Printf("Creating remote subscriber for %s on %s", publisherId, cmd.RemoteUrl) + s.logger.Printf("Creating remote subscriber for %s on %s", publisherId, cmd.RemoteUrl) controller := &proxyRemotePublisher{ proxy: s, @@ -1107,7 +1110,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s return } - log.Printf("Created remote %s subscriber %s as %s for %s on %s", cmd.StreamType, subscriber.Id(), id, session.PublicId(), cmd.RemoteUrl) + s.logger.Printf("Created remote %s subscriber %s as %s for %s on %s", cmd.StreamType, subscriber.Id(), id, session.PublicId(), cmd.RemoteUrl) } else { ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout) defer cancel() @@ -1118,7 +1121,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s return } - log.Printf("Created %s subscriber %s as %s for %s", cmd.StreamType, subscriber.Id(), id, session.PublicId()) + s.logger.Printf("Created %s subscriber %s as %s for %s", cmd.StreamType, subscriber.Id(), id, session.PublicId()) } session.StoreSubscriber(ctx, id, subscriber) @@ -1158,7 +1161,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } go func() { - log.Printf("Closing %s publisher %s as %s", client.StreamType(), client.Id(), cmd.ClientId) + s.logger.Printf("Closing %s publisher %s as %s", client.StreamType(), client.Id(), cmd.ClientId) client.Close(context.Background()) }() @@ -1193,7 +1196,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } go func() { - log.Printf("Closing %s subscriber %s as %s", client.StreamType(), client.Id(), cmd.ClientId) + s.logger.Printf("Closing %s subscriber %s as %s", client.StreamType(), client.Id(), cmd.ClientId) client.Close(context.Background()) }() @@ -1224,7 +1227,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s if err := publisher.PublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { var je *janus.ErrorMsg if !errors.As(err, &je) || je.Err.Code != signaling.JANUS_VIDEOROOM_ERROR_ID_EXISTS { - log.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) + s.logger.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) return } @@ -1233,7 +1236,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s defer cancel() if err := publisher.UnpublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { - log.Printf("Error unpublishing old %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) + s.logger.Printf("Error unpublishing old %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) return } @@ -1242,7 +1245,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s defer cancel() if err := publisher.PublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { - log.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) + s.logger.Printf("Error publishing %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) return } @@ -1274,7 +1277,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s defer cancel() if err := publisher.UnpublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { - log.Printf("Error unpublishing %s %s from remote %s: %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, err) + s.logger.Printf("Error unpublishing %s %s from remote %s: %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) return } @@ -1304,7 +1307,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s streams, err := publisher.GetStreams(ctx) if err != nil { - log.Printf("Could not get streams of publisher %s: %s", publisher.Id(), err) + s.logger.Printf("Could not get streams of publisher %s: %s", publisher.Id(), err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) return } @@ -1319,7 +1322,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } session.sendMessage(response) default: - log.Printf("Unsupported command %+v", message.Command) + s.logger.Printf("Unsupported command %+v", message.Command) session.sendMessage(message.NewErrorServerMessage(UnsupportedCommand)) } } @@ -1374,7 +1377,7 @@ func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, s } if err := mcuData.CheckValid(); err != nil { - log.Printf("Received invalid payload %+v for %s client %s: %s", mcuData, mcuClient.StreamType(), payload.ClientId, err) + s.logger.Printf("Received invalid payload %+v for %s client %s: %s", mcuData, mcuClient.StreamType(), payload.ClientId, err) session.sendMessage(message.NewErrorServerMessage(UnsupportedPayload)) return } @@ -1390,7 +1393,7 @@ func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, s } if err != nil { - log.Printf("Error sending %+v to %s client %s: %s", mcuData, mcuClient.StreamType(), payload.ClientId, err) + s.logger.Printf("Error sending %+v to %s client %s: %s", mcuData, mcuClient.StreamType(), payload.ClientId, err) responseMsg = message.NewWrappedErrorServerMessage(err) } else { responseMsg = &signaling.ProxyServerMessage{ @@ -1409,7 +1412,7 @@ func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, s } func (s *ProxyServer) processBye(ctx context.Context, client *ProxyClient, session *ProxySession, message *signaling.ProxyClientMessage) { - log.Printf("Closing session %s", session.PublicId()) + s.logger.Printf("Closing session %s", session.PublicId()) s.DeleteSession(session.Sid()) } @@ -1418,27 +1421,27 @@ func (s *ProxyServer) parseToken(tokenValue string) (*signaling.TokenClaims, str token, err := jwt.ParseWithClaims(tokenValue, &signaling.TokenClaims{}, func(token *jwt.Token) (any, error) { // Don't forget to validate the alg is what you expect: if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - log.Printf("Unexpected signing method: %v", token.Header["alg"]) + s.logger.Printf("Unexpected signing method: %v", token.Header["alg"]) reason = "unsupported-signing-method" return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } claims, ok := token.Claims.(*signaling.TokenClaims) if !ok { - log.Printf("Unsupported claims type: %+v", token.Claims) + s.logger.Printf("Unsupported claims type: %+v", token.Claims) reason = "unsupported-claims" return nil, fmt.Errorf("unsupported claims type") } tokenKey, err := s.tokens.Get(claims.Issuer) if err != nil { - log.Printf("Could not get token for %s: %s", claims.Issuer, err) + s.logger.Printf("Could not get token for %s: %s", claims.Issuer, err) reason = "missing-issuer" return nil, err } if tokenKey == nil || tokenKey.key == nil { - log.Printf("Issuer %s is not supported", claims.Issuer) + s.logger.Printf("Issuer %s is not supported", claims.Issuer) reason = "unsupported-issuer" return nil, fmt.Errorf("no key found for issuer") } @@ -1475,7 +1478,7 @@ func (s *ProxyServer) parseToken(tokenValue string) (*signaling.TokenClaims, str func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*ProxySession, error) { if proxyDebugMessages { - log.Printf("Hello: %+v", hello) + s.logger.Printf("Hello: %+v", hello) } claims, reason, err := s.parseToken(hello.Token) @@ -1499,7 +1502,7 @@ func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*Pro return nil, err } - log.Printf("Created session %s for %+v", encoded, claims) + s.logger.Printf("Created session %s for %+v", encoded, claims) session := NewProxySession(s, sid, encoded) s.StoreSession(sid, session) statsSessionsCurrent.Inc() @@ -1669,7 +1672,7 @@ func (s *ProxyServer) statsHandler(w http.ResponseWriter, r *http.Request) { stats := s.getStats() statsData, err := json.MarshalIndent(stats, "", " ") if err != nil { - log.Printf("Could not serialize stats %+v: %s", stats, err) + s.logger.Printf("Could not serialize stats %+v: %s", stats, err) http.Error(w, "Internal server error", http.StatusInternalServerError) return } diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index 39b9324..76b9b3d 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -138,7 +138,9 @@ func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey, *httpte config := goconf.NewConfigFile() config.AddOption("tokens", TokenIdForTest, pubkey.Name()) - proxy, err = NewProxyServer(r, "0.0", config) + logger := signaling.NewLoggerForTest(t) + ctx := signaling.NewLoggerContext(t.Context(), logger) + proxy, err = NewProxyServer(ctx, r, "0.0", config) require.NoError(err) server := httptest.NewServer(r) @@ -150,7 +152,6 @@ func newProxyServerForTest(t *testing.T) (*ProxyServer, *rsa.PrivateKey, *httpte } func TestTokenValid(t *testing.T) { - signaling.CatchLogForTest(t) proxy, key, _ := newProxyServerForTest(t) claims := &signaling.TokenClaims{ @@ -173,7 +174,6 @@ func TestTokenValid(t *testing.T) { } func TestTokenNotSigned(t *testing.T) { - signaling.CatchLogForTest(t) proxy, _, _ := newProxyServerForTest(t) claims := &signaling.TokenClaims{ @@ -198,7 +198,6 @@ func TestTokenNotSigned(t *testing.T) { } func TestTokenUnknown(t *testing.T) { - signaling.CatchLogForTest(t) proxy, key, _ := newProxyServerForTest(t) claims := &signaling.TokenClaims{ @@ -223,7 +222,6 @@ func TestTokenUnknown(t *testing.T) { } func TestTokenInFuture(t *testing.T) { - signaling.CatchLogForTest(t) proxy, key, _ := newProxyServerForTest(t) claims := &signaling.TokenClaims{ @@ -248,7 +246,6 @@ func TestTokenInFuture(t *testing.T) { } func TestTokenExpired(t *testing.T) { - signaling.CatchLogForTest(t) proxy, key, _ := newProxyServerForTest(t) claims := &signaling.TokenClaims{ @@ -305,7 +302,6 @@ func TestPublicIPs(t *testing.T) { } func TestWebsocketFeatures(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) _, _, server := newProxyServerForTest(t) @@ -333,7 +329,6 @@ func TestWebsocketFeatures(t *testing.T) { } func TestProxyCreateSession(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) require := require.New(t) _, key, server := newProxyServerForTest(t) @@ -485,7 +480,6 @@ func NewPublisherTestMCU(t *testing.T) *PublisherTestMCU { } func TestProxyPublisherBandwidth(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) require := require.New(t) proxy, key, server := newProxyServerForTest(t) @@ -605,7 +599,6 @@ func (m *HangingTestMCU) NewSubscriber(ctx context.Context, listener signaling.M } func TestProxyCancelOnClose(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) require := require.New(t) proxy, key, server := newProxyServerForTest(t) @@ -684,7 +677,6 @@ func (m *CodecsTestMCU) NewPublisher(ctx context.Context, listener signaling.Mcu } func TestProxyCodecs(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) require := require.New(t) proxy, key, server := newProxyServerForTest(t) @@ -767,7 +759,6 @@ func NewStreamTestMCU(t *testing.T, streams []signaling.PublisherStream) *Stream } func TestProxyStreams(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) require := require.New(t) proxy, key, server := newProxyServerForTest(t) @@ -992,7 +983,6 @@ func (m *RemoteSubscriberTestMCU) NewRemoteSubscriber(ctx context.Context, liste } func TestProxyRemoteSubscriber(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) require := require.New(t) proxy, key, server := newProxyServerForTest(t) @@ -1087,7 +1077,6 @@ func TestProxyRemoteSubscriber(t *testing.T) { } func TestProxyCloseRemoteOnSessionClose(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) require := require.New(t) proxy, key, server := newProxyServerForTest(t) @@ -1250,7 +1239,6 @@ func (p *UnpublishRemoteTestPublisher) UnpublishRemote(ctx context.Context, remo } func TestProxyUnpublishRemote(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) require := require.New(t) proxy, key, server := newProxyServerForTest(t) @@ -1367,7 +1355,6 @@ func TestProxyUnpublishRemote(t *testing.T) { } func TestProxyUnpublishRemotePublisherClosed(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) require := require.New(t) proxy, key, server := newProxyServerForTest(t) @@ -1499,7 +1486,6 @@ func TestProxyUnpublishRemotePublisherClosed(t *testing.T) { } func TestProxyUnpublishRemoteOnSessionClose(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) require := require.New(t) proxy, key, server := newProxyServerForTest(t) diff --git a/proxy/proxy_session.go b/proxy/proxy_session.go index c97acc5..750ac2a 100644 --- a/proxy/proxy_session.go +++ b/proxy/proxy_session.go @@ -24,7 +24,6 @@ package main import ( "context" "fmt" - "log" "sync" "sync/atomic" "time" @@ -46,6 +45,7 @@ type remotePublisherData struct { } type ProxySession struct { + logger signaling.Logger proxy *ProxyServer id signaling.PublicSessionId sid uint64 @@ -79,6 +79,7 @@ type ProxySession struct { func NewProxySession(proxy *ProxyServer, sid uint64, id signaling.PublicSessionId) *ProxySession { ctx, closeFunc := context.WithCancel(context.Background()) result := &ProxySession{ + logger: proxy.logger, proxy: proxy, id: id, sid: sid, @@ -169,7 +170,7 @@ func (s *ProxySession) SetClient(client *ProxyClient) *ProxyClient { func (s *ProxySession) OnUpdateOffer(client signaling.McuClient, offer api.StringMap) { id := s.proxy.GetClientId(client) if id == "" { - log.Printf("Received offer %+v from unknown %s client %s (%+v)", offer, client.StreamType(), client.Id(), client) + s.logger.Printf("Received offer %+v from unknown %s client %s (%+v)", offer, client.StreamType(), client.Id(), client) return } @@ -189,7 +190,7 @@ func (s *ProxySession) OnUpdateOffer(client signaling.McuClient, offer api.Strin func (s *ProxySession) OnIceCandidate(client signaling.McuClient, candidate any) { id := s.proxy.GetClientId(client) if id == "" { - log.Printf("Received candidate %+v from unknown %s client %s (%+v)", candidate, client.StreamType(), client.Id(), client) + s.logger.Printf("Received candidate %+v from unknown %s client %s (%+v)", candidate, client.StreamType(), client.Id(), client) return } @@ -222,7 +223,7 @@ func (s *ProxySession) sendMessage(message *signaling.ProxyServerMessage) { func (s *ProxySession) OnIceCompleted(client signaling.McuClient) { id := s.proxy.GetClientId(client) if id == "" { - log.Printf("Received ice completed event from unknown %s client %s (%+v)", client.StreamType(), client.Id(), client) + s.logger.Printf("Received ice completed event from unknown %s client %s (%+v)", client.StreamType(), client.Id(), client) return } @@ -239,7 +240,7 @@ func (s *ProxySession) OnIceCompleted(client signaling.McuClient) { func (s *ProxySession) SubscriberSidUpdated(subscriber signaling.McuSubscriber) { id := s.proxy.GetClientId(subscriber) if id == "" { - log.Printf("Received subscriber sid updated event from unknown %s subscriber %s (%+v)", subscriber.StreamType(), subscriber.Id(), subscriber) + s.logger.Printf("Received subscriber sid updated event from unknown %s subscriber %s (%+v)", subscriber.StreamType(), subscriber.Id(), subscriber) return } @@ -363,7 +364,7 @@ func (s *ProxySession) clearRemotePublishers() { for publisher, entries := range remotePublishers { for _, data := range entries { if err := publisher.UnpublishRemote(context.Background(), s.PublicId(), data.hostname, data.port, data.rtcpPort); err != nil { - log.Printf("Error unpublishing %s %s from remote %s: %s", publisher.StreamType(), publisher.Id(), data.hostname, err) + s.logger.Printf("Error unpublishing %s %s from remote %s: %s", publisher.StreamType(), publisher.Id(), data.hostname, err) } } } @@ -476,7 +477,7 @@ func (s *ProxySession) OnRemotePublisherDeleted(publisherId signaling.PublicSess delete(s.subscribers, id) delete(s.subscriberIds, sub) - log.Printf("Remote subscriber %s was closed, closing %s subscriber %s", publisherId, sub.StreamType(), sub.Id()) + s.logger.Printf("Remote subscriber %s was closed, closing %s subscriber %s", publisherId, sub.StreamType(), sub.Id()) go sub.Close(context.Background()) } } diff --git a/proxy/proxy_tokens_etcd.go b/proxy/proxy_tokens_etcd.go index f542d62..ccaa837 100644 --- a/proxy/proxy_tokens_etcd.go +++ b/proxy/proxy_tokens_etcd.go @@ -25,7 +25,6 @@ import ( "bytes" "context" "fmt" - "log" "strings" "sync/atomic" "time" @@ -46,14 +45,15 @@ type tokenCacheEntry struct { } type tokensEtcd struct { + logger signaling.Logger client *signaling.EtcdClient tokenFormats atomic.Value tokenCache *signaling.LruCache[*tokenCacheEntry] } -func NewProxyTokensEtcd(config *goconf.ConfigFile) (ProxyTokens, error) { - client, err := signaling.NewEtcdClient(config, "tokens") +func NewProxyTokensEtcd(logger signaling.Logger, config *goconf.ConfigFile) (ProxyTokens, error) { + client, err := signaling.NewEtcdClient(logger, config, "tokens") if err != nil { return nil, err } @@ -63,6 +63,7 @@ func NewProxyTokensEtcd(config *goconf.ConfigFile) (ProxyTokens, error) { } result := &tokensEtcd{ + logger: logger, client: client, tokenCache: signaling.NewLruCache[*tokenCacheEntry](tokenCacheSize), } @@ -94,7 +95,7 @@ func (t *tokensEtcd) getByKey(id string, key string) (*ProxyToken, error) { if len(resp.Kvs) == 0 { return nil, nil } else if len(resp.Kvs) > 1 { - log.Printf("Received multiple keys for %s, using last", key) + t.logger.Printf("Received multiple keys for %s, using last", key) } keyValue := resp.Kvs[len(resp.Kvs)-1].Value @@ -123,7 +124,7 @@ func (t *tokensEtcd) Get(id string) (*ProxyToken, error) { for _, k := range t.getKeys(id) { token, err := t.getByKey(id, k) if err != nil { - log.Printf("Could not get public key from %s for %s: %s", k, id, err) + t.logger.Printf("Could not get public key from %s for %s: %s", k, id, err) continue } else if token == nil { continue @@ -151,18 +152,18 @@ func (t *tokensEtcd) load(config *goconf.ConfigFile, ignoreErrors bool) error { } t.tokenFormats.Store(tokenFormats) - log.Printf("Using %v as token formats", tokenFormats) + t.logger.Printf("Using %v as token formats", tokenFormats) return nil } func (t *tokensEtcd) Reload(config *goconf.ConfigFile) { if err := t.load(config, true); err != nil { - log.Printf("Error reloading etcd tokens: %s", err) + t.logger.Printf("Error reloading etcd tokens: %s", err) } } func (t *tokensEtcd) Close() { if err := t.client.Close(); err != nil { - log.Printf("Error while closing etcd client: %s", err) + t.logger.Printf("Error while closing etcd client: %s", err) } } diff --git a/proxy/proxy_tokens_etcd_test.go b/proxy/proxy_tokens_etcd_test.go index 8b23ff3..be600f2 100644 --- a/proxy/proxy_tokens_etcd_test.go +++ b/proxy/proxy_tokens_etcd_test.go @@ -115,7 +115,8 @@ func newTokensEtcdForTesting(t *testing.T) (*tokensEtcd, *embed.Etcd) { cfg.AddOption("etcd", "endpoints", etcd.Config().ListenClientUrls[0].String()) cfg.AddOption("tokens", "keyformat", "/%s, /testing/%s/key") - tokens, err := NewProxyTokensEtcd(cfg) + logger := signaling.NewLoggerForTest(t) + tokens, err := NewProxyTokensEtcd(logger, cfg) require.NoError(t, err) t.Cleanup(func() { tokens.Close() @@ -155,7 +156,6 @@ func generateAndSaveKey(t *testing.T, etcd *embed.Etcd, name string) *rsa.Privat } func TestProxyTokensEtcd(t *testing.T) { - signaling.CatchLogForTest(t) assert := assert.New(t) tokens, etcd := newTokensEtcdForTesting(t) diff --git a/proxy/proxy_tokens_static.go b/proxy/proxy_tokens_static.go index 37c23e4..8dc2118 100644 --- a/proxy/proxy_tokens_static.go +++ b/proxy/proxy_tokens_static.go @@ -23,7 +23,6 @@ package main import ( "fmt" - "log" "os" "slices" "sync/atomic" @@ -35,11 +34,14 @@ import ( ) type tokensStatic struct { + logger signaling.Logger tokenKeys atomic.Value } -func NewProxyTokensStatic(config *goconf.ConfigFile) (ProxyTokens, error) { - result := &tokensStatic{} +func NewProxyTokensStatic(logger signaling.Logger, config *goconf.ConfigFile) (ProxyTokens, error) { + result := &tokensStatic{ + logger: logger, + } if err := result.load(config, false); err != nil { return nil, err } @@ -74,7 +76,7 @@ func (t *tokensStatic) load(config *goconf.ConfigFile, ignoreErrors bool) error return fmt.Errorf("no filename given for token %s", id) } - log.Printf("No filename given for token %s, ignoring", id) + t.logger.Printf("No filename given for token %s, ignoring", id) continue } @@ -84,7 +86,7 @@ func (t *tokensStatic) load(config *goconf.ConfigFile, ignoreErrors bool) error return fmt.Errorf("could not read public key from %s: %s", filename, err) } - log.Printf("Could not read public key from %s, ignoring: %s", filename, err) + t.logger.Printf("Could not read public key from %s, ignoring: %s", filename, err) continue } key, err := jwt.ParseRSAPublicKeyFromPEM(keyData) @@ -93,7 +95,7 @@ func (t *tokensStatic) load(config *goconf.ConfigFile, ignoreErrors bool) error return fmt.Errorf("could not parse public key from %s: %s", filename, err) } - log.Printf("Could not parse public key from %s, ignoring: %s", filename, err) + t.logger.Printf("Could not parse public key from %s, ignoring: %s", filename, err) continue } @@ -104,14 +106,14 @@ func (t *tokensStatic) load(config *goconf.ConfigFile, ignoreErrors bool) error } if len(tokenKeys) == 0 { - log.Printf("No token keys loaded") + t.logger.Printf("No token keys loaded") } else { var keyIds []string for k := range tokenKeys { keyIds = append(keyIds, k) } slices.Sort(keyIds) - log.Printf("Enabled token keys: %v", keyIds) + t.logger.Printf("Enabled token keys: %v", keyIds) } t.setTokenKeys(tokenKeys) return nil @@ -119,7 +121,7 @@ func (t *tokensStatic) load(config *goconf.ConfigFile, ignoreErrors bool) error func (t *tokensStatic) Reload(config *goconf.ConfigFile) { if err := t.load(config, true); err != nil { - log.Printf("Error reloading static tokens: %s", err) + t.logger.Printf("Error reloading static tokens: %s", err) } } diff --git a/proxy_config_etcd.go b/proxy_config_etcd.go index faf1582..94365c4 100644 --- a/proxy_config_etcd.go +++ b/proxy_config_etcd.go @@ -25,7 +25,6 @@ import ( "context" "encoding/json" "errors" - "log" "sync" "time" @@ -34,8 +33,9 @@ import ( ) type proxyConfigEtcd struct { - mu sync.Mutex - proxy McuProxy // +checklocksignore: Only written to from constructor. + logger Logger + mu sync.Mutex + proxy McuProxy // +checklocksignore: Only written to from constructor. client *EtcdClient keyPrefix string @@ -48,7 +48,7 @@ type proxyConfigEtcd struct { closeFunc context.CancelFunc } -func NewProxyConfigEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient, proxy McuProxy) (ProxyConfig, error) { +func NewProxyConfigEtcd(logger Logger, config *goconf.ConfigFile, etcdClient *EtcdClient, proxy McuProxy) (ProxyConfig, error) { if !etcdClient.IsConfigured() { return nil, errors.New("no etcd endpoints configured") } @@ -56,7 +56,8 @@ func NewProxyConfigEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient, proxy closeCtx, closeFunc := context.WithCancel(context.Background()) result := &proxyConfigEtcd{ - proxy: proxy, + logger: logger, + proxy: proxy, client: etcdClient, keyInfos: make(map[string]*ProxyInformationEtcd), @@ -118,9 +119,9 @@ func (p *proxyConfigEtcd) EtcdClientCreated(client *EtcdClient) { if errors.Is(err, context.Canceled) { return } else if errors.Is(err, context.DeadlineExceeded) { - log.Printf("Timeout getting initial list of proxy URLs, retry in %s", backoff.NextWait()) + p.logger.Printf("Timeout getting initial list of proxy URLs, retry in %s", backoff.NextWait()) } else { - log.Printf("Could not get initial list of proxy URLs, retry in %s: %s", backoff.NextWait(), err) + p.logger.Printf("Could not get initial list of proxy URLs, retry in %s: %s", backoff.NextWait(), err) } backoff.Wait(p.closeCtx) @@ -139,7 +140,7 @@ func (p *proxyConfigEtcd) EtcdClientCreated(client *EtcdClient) { for p.closeCtx.Err() == nil { var err error if nextRevision, err = client.Watch(p.closeCtx, p.keyPrefix, nextRevision, p, clientv3.WithPrefix()); err != nil { - log.Printf("Error processing watch for %s (%s), retry in %s", p.keyPrefix, err, backoff.NextWait()) + p.logger.Printf("Error processing watch for %s (%s), retry in %s", p.keyPrefix, err, backoff.NextWait()) backoff.Wait(p.closeCtx) continue } @@ -148,7 +149,7 @@ func (p *proxyConfigEtcd) EtcdClientCreated(client *EtcdClient) { backoff.Reset() prevRevision = nextRevision } else { - log.Printf("Processing watch for %s interrupted, retry in %s", p.keyPrefix, backoff.NextWait()) + p.logger.Printf("Processing watch for %s interrupted, retry in %s", p.keyPrefix, backoff.NextWait()) backoff.Wait(p.closeCtx) } } @@ -168,11 +169,11 @@ func (p *proxyConfigEtcd) getProxyUrls(ctx context.Context, client *EtcdClient, func (p *proxyConfigEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data []byte, prevValue []byte) { var info ProxyInformationEtcd if err := json.Unmarshal(data, &info); err != nil { - log.Printf("Could not decode proxy information %s: %s", string(data), err) + p.logger.Printf("Could not decode proxy information %s: %s", string(data), err) return } if err := info.CheckValid(); err != nil { - log.Printf("Received invalid proxy information %s: %s", string(data), err) + p.logger.Printf("Received invalid proxy information %s: %s", string(data), err) return } @@ -187,7 +188,7 @@ func (p *proxyConfigEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data [] } if otherKey, otherFound := p.urlToKey[info.Address]; otherFound && otherKey != key { - log.Printf("Address %s is already registered for key %s, ignoring %s", info.Address, otherKey, key) + p.logger.Printf("Address %s is already registered for key %s, ignoring %s", info.Address, otherKey, key) return } @@ -196,11 +197,11 @@ func (p *proxyConfigEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data [] p.proxy.KeepConnection(info.Address) } else { if err := p.proxy.AddConnection(false, info.Address); err != nil { - log.Printf("Could not create proxy connection to %s: %s", info.Address, err) + p.logger.Printf("Could not create proxy connection to %s: %s", info.Address, err) return } - log.Printf("Added new connection to %s (from %s)", info.Address, key) + p.logger.Printf("Added new connection to %s (from %s)", info.Address, key) p.keyInfos[key] = &info p.urlToKey[info.Address] = key } @@ -223,6 +224,6 @@ func (p *proxyConfigEtcd) removeEtcdProxyLocked(key string) { delete(p.keyInfos, key) delete(p.urlToKey, info.Address) - log.Printf("Removing connection to %s (from %s)", info.Address, key) + p.logger.Printf("Removing connection to %s (from %s)", info.Address, key) p.proxy.RemoveConnection(info.Address) } diff --git a/proxy_config_etcd_test.go b/proxy_config_etcd_test.go index 353f690..814b9c3 100644 --- a/proxy_config_etcd_test.go +++ b/proxy_config_etcd_test.go @@ -43,7 +43,8 @@ func newProxyConfigEtcd(t *testing.T, proxy McuProxy) (*embed.Etcd, ProxyConfig) etcd, client := NewEtcdClientForTest(t) cfg := goconf.NewConfigFile() cfg.AddOption("mcu", "keyprefix", "proxies/") - p, err := NewProxyConfigEtcd(cfg, client, proxy) + logger := NewLoggerForTest(t) + p, err := NewProxyConfigEtcd(logger, cfg, client, proxy) require.NoError(t, err) t.Cleanup(func() { p.Stop() @@ -60,11 +61,10 @@ func SetEtcdProxy(t *testing.T, etcd *embed.Etcd, path string, proxy *TestProxyI func TestProxyConfigEtcd(t *testing.T) { t.Parallel() - CatchLogForTest(t) proxy := newMcuProxyForConfig(t) etcd, config := newProxyConfigEtcd(t, proxy) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(t.Context(), time.Second) defer cancel() SetEtcdProxy(t, etcd, "proxies/a", &TestProxyInformationEtcd{ diff --git a/proxy_config_static.go b/proxy_config_static.go index dcdaa59..0dcaa55 100644 --- a/proxy_config_static.go +++ b/proxy_config_static.go @@ -23,7 +23,6 @@ package signaling import ( "errors" - "log" "maps" "net" "net/url" @@ -40,8 +39,9 @@ type ipList struct { } type proxyConfigStatic struct { - mu sync.Mutex - proxy McuProxy + logger Logger + mu sync.Mutex + proxy McuProxy dnsMonitor *DnsMonitor // +checklocks:mu @@ -51,8 +51,9 @@ type proxyConfigStatic struct { connectionsMap map[string]*ipList } -func NewProxyConfigStatic(config *goconf.ConfigFile, proxy McuProxy, dnsMonitor *DnsMonitor) (ProxyConfig, error) { +func NewProxyConfigStatic(logger Logger, config *goconf.ConfigFile, proxy McuProxy, dnsMonitor *DnsMonitor) (ProxyConfig, error) { result := &proxyConfigStatic{ + logger: logger, proxy: proxy, dnsMonitor: dnsMonitor, connectionsMap: make(map[string]*ipList), @@ -100,7 +101,7 @@ func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool return err } - log.Printf("Could not parse URL %s: %s", u, err) + p.logger.Printf("Could not parse URL %s: %s", u, err) continue } @@ -121,7 +122,7 @@ func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool return err } - log.Printf("Could not create proxy connection to %s: %s", u, err) + p.logger.Printf("Could not create proxy connection to %s: %s", u, err) continue } } @@ -198,7 +199,7 @@ func (p *proxyConfigStatic) onLookup(entry *DnsMonitorEntry, all []net.IP, added if len(added) > 0 { if err := p.proxy.AddConnection(true, u, added...); err != nil { - log.Printf("Could not add proxy connection to %s with %+v: %s", u, added, err) + p.logger.Printf("Could not add proxy connection to %s with %+v: %s", u, added, err) } } diff --git a/proxy_config_static_test.go b/proxy_config_static_test.go index 70884e8..5354d48 100644 --- a/proxy_config_static_test.go +++ b/proxy_config_static_test.go @@ -38,7 +38,8 @@ func newProxyConfigStatic(t *testing.T, proxy McuProxy, dns bool, urls ...string cfg.AddOption("mcu", "dnsdiscovery", "true") } dnsMonitor := newDnsMonitorForTest(t, time.Hour) // will be updated manually - p, err := NewProxyConfigStatic(cfg, proxy, dnsMonitor) + logger := NewLoggerForTest(t) + p, err := NewProxyConfigStatic(logger, cfg, proxy, dnsMonitor) require.NoError(t, err) t.Cleanup(func() { p.Stop() @@ -56,7 +57,6 @@ func updateProxyConfigStatic(t *testing.T, config ProxyConfig, dns bool, urls .. } func TestProxyConfigStaticSimple(t *testing.T) { - CatchLogForTest(t) proxy := newMcuProxyForConfig(t) config, _ := newProxyConfigStatic(t, proxy, false, "https://foo/") proxy.Expect("add", "https://foo/") @@ -73,7 +73,6 @@ func TestProxyConfigStaticSimple(t *testing.T) { } func TestProxyConfigStaticDNS(t *testing.T) { - CatchLogForTest(t) lookup := newMockDnsLookupForTest(t) proxy := newMcuProxyForConfig(t) config, dnsMonitor := newProxyConfigStatic(t, proxy, true, "https://foo/") diff --git a/remotesession.go b/remotesession.go index c548f9e..0787785 100644 --- a/remotesession.go +++ b/remotesession.go @@ -26,12 +26,12 @@ import ( "encoding/json" "errors" "fmt" - "log" "sync/atomic" "time" ) type RemoteSession struct { + logger Logger hub *Hub client *Client remoteClient *GrpcClient @@ -42,6 +42,7 @@ type RemoteSession struct { func NewRemoteSession(hub *Hub, client *Client, remoteClient *GrpcClient, sessionId PublicSessionId) (*RemoteSession, error) { remoteSession := &RemoteSession{ + logger: hub.logger, hub: hub, client: client, remoteClient: remoteClient, @@ -97,7 +98,7 @@ func (s *RemoteSession) OnProxyMessage(msg *ServerSessionMessage) error { func (s *RemoteSession) OnProxyClose(err error) { if err != nil { - log.Printf("Proxy connection for session %s to %s was closed with error: %s", s.sessionId, s.remoteClient.Target(), err) + s.logger.Printf("Proxy connection for session %s to %s was closed with error: %s", s.sessionId, s.remoteClient.Target(), err) } s.Close() } @@ -145,7 +146,7 @@ func (s *RemoteSession) OnClosed(client HandlerClient) { func (s *RemoteSession) OnMessageReceived(client HandlerClient, message []byte) { if err := s.sendProxyMessage(message); err != nil { - log.Printf("Error sending %s to the proxy for session %s: %s", string(message), s.sessionId, err) + s.logger.Printf("Error sending %s to the proxy for session %s: %s", string(message), s.sessionId, err) s.Close() } } diff --git a/room.go b/room.go index 372cf5b..6e75077 100644 --- a/room.go +++ b/room.go @@ -26,7 +26,6 @@ import ( "context" "encoding/json" "fmt" - "log" "maps" "net/url" "strconv" @@ -65,6 +64,7 @@ func init() { type Room struct { id string + logger Logger hub *Hub events AsyncEvents backend *Backend @@ -108,6 +108,7 @@ func getRoomIdForBackend(id string, backend *Backend) string { func NewRoom(roomId string, properties json.RawMessage, hub *Hub, events AsyncEvents, backend *Backend) (*Room, error) { room := &Room{ id: roomId, + logger: hub.logger, hub: hub, events: events, backend: backend, @@ -223,7 +224,7 @@ func (r *Room) ProcessBackendRoomRequest(message *AsyncMessage) { case "asyncroom": r.processBackendRoomRequestAsyncRoom(message.AsyncRoom) default: - log.Printf("Unsupported backend room request with type %s in %s: %+v", message.Type, r.id, message) + r.logger.Printf("Unsupported backend room request with type %s in %s: %+v", message.Type, r.id, message) } } @@ -231,9 +232,9 @@ func (r *Room) processBackendRoomRequestRoom(message *BackendServerRoomRequest) received := message.ReceivedTime if last, found := r.lastRoomRequests[message.Type]; found && last > received { if msg, err := json.Marshal(message); err == nil { - log.Printf("Ignore old backend room request for %s: %s", r.Id(), string(msg)) + r.logger.Printf("Ignore old backend room request for %s: %s", r.Id(), string(msg)) } else { - log.Printf("Ignore old backend room request for %s: %+v", r.Id(), message) + r.logger.Printf("Ignore old backend room request for %s: %+v", r.Id(), message) } return } @@ -261,10 +262,10 @@ func (r *Room) processBackendRoomRequestRoom(message *BackendServerRoomRequest) case TransientActionDelete: r.RemoveTransientData(message.Transient.Key) default: - log.Printf("Unsupported transient action in room %s: %+v", r.Id(), message.Transient) + r.logger.Printf("Unsupported transient action in room %s: %+v", r.Id(), message.Transient) } default: - log.Printf("Unsupported backend room request with type %s in %s: %+v", message.Type, r.Id(), message) + r.logger.Printf("Unsupported backend room request with type %s in %s: %+v", message.Type, r.Id(), message) } } @@ -276,7 +277,7 @@ func (r *Room) processBackendRoomRequestAsyncRoom(message *AsyncRoomMessage) { r.publishUsersChangedWithInternal() } default: - log.Printf("Unsupported async room request with type %s in %s: %+v", message.Type, r.Id(), message) + r.logger.Printf("Unsupported async room request with type %s in %s: %+v", message.Type, r.Id(), message) } } @@ -285,7 +286,7 @@ func (r *Room) AddSession(session Session, sessionData json.RawMessage) { if len(sessionData) > 0 { roomSessionData = &RoomSessionData{} if err := json.Unmarshal(sessionData, roomSessionData); err != nil { - log.Printf("Error decoding room session data \"%s\": %s", string(sessionData), err) + r.logger.Printf("Error decoding room session data \"%s\": %s", string(sessionData), err) roomSessionData = nil } } @@ -319,7 +320,7 @@ func (r *Room) AddSession(session Session, sessionData json.RawMessage) { } if roomSessionData != nil { r.roomSessionData[sid] = roomSessionData - log.Printf("Session %s sent room session data %+v", session.PublicId(), roomSessionData) + r.logger.Printf("Session %s sent room session data %+v", session.PublicId(), roomSessionData) } r.mu.Unlock() if !found { @@ -344,7 +345,7 @@ func (r *Room) AddSession(session Session, sessionData json.RawMessage) { ClientType: session.ClientType(), }, }); err != nil { - log.Printf("Error publishing joined event for session %s: %s", sid, err) + r.logger.Printf("Error publishing joined event for session %s: %s", sid, err) } } @@ -402,7 +403,7 @@ func (r *Room) notifySessionJoined(sessionId PublicSessionId) { Type: "message", Message: msg, }); err != nil { - log.Printf("Error publishing joined events to session %s: %s", sessionId, err) + r.logger.Printf("Error publishing joined events to session %s: %s", sessionId, err) } // Notify about initial flags of virtual sessions. @@ -434,7 +435,7 @@ func (r *Room) notifySessionJoined(sessionId PublicSessionId) { Type: "message", Message: msg, }); err != nil { - log.Printf("Error publishing initial flags to session %s: %s", sessionId, err) + r.logger.Printf("Error publishing initial flags to session %s: %s", sessionId, err) } } } @@ -526,7 +527,7 @@ func (r *Room) UpdateProperties(properties json.RawMessage) { }, } if err := r.publish(message); err != nil { - log.Printf("Could not publish update properties message in room %s: %s", r.Id(), err) + r.logger.Printf("Could not publish update properties message in room %s: %s", r.Id(), err) } } @@ -567,7 +568,7 @@ func (r *Room) PublishSessionJoined(session Session, sessionData *RoomSessionDat message.Event.Join[0].Federated = session.ClientType() == HelloClientTypeFederation } if err := r.publish(message); err != nil { - log.Printf("Could not publish session joined message in room %s: %s", r.Id(), err) + r.logger.Printf("Could not publish session joined message in room %s: %s", r.Id(), err) } } @@ -588,7 +589,7 @@ func (r *Room) PublishSessionLeft(session Session) { }, } if err := r.publish(message); err != nil { - log.Printf("Could not publish session left message in room %s: %s", r.Id(), err) + r.logger.Printf("Could not publish session left message in room %s: %s", r.Id(), err) } if session.ClientType() == HelloClientTypeInternal { @@ -604,7 +605,8 @@ func (r *Room) getClusteredInternalSessionsRLocked() (internal map[PublicSession r.mu.RUnlock() defer r.mu.RLock() - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx := NewLoggerContext(context.Background(), r.logger) + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() var mu sync.Mutex @@ -616,7 +618,7 @@ func (r *Room) getClusteredInternalSessionsRLocked() (internal map[PublicSession clientInternal, clientVirtual, err := c.GetInternalSessions(ctx, r.Id(), r.Backend().Urls()) if err != nil { - log.Printf("Received error while getting internal sessions for %s@%s from %s: %s", r.Id(), r.Backend().Id(), c.Target(), err) + r.logger.Printf("Received error while getting internal sessions for %s@%s from %s: %s", r.Id(), r.Backend().Id(), c.Target(), err) return } @@ -786,7 +788,7 @@ func (r *Room) PublishUsersInCallChanged(changed []api.StringMap, users []api.St r.mu.Lock() if !r.inCallSessions[session] { r.inCallSessions[session] = true - log.Printf("Session %s joined call %s", session.PublicId(), r.id) + r.logger.Printf("Session %s joined call %s", session.PublicId(), r.id) } r.mu.Unlock() } else { @@ -815,7 +817,7 @@ func (r *Room) PublishUsersInCallChanged(changed []api.StringMap, users []api.St }, } if err := r.publish(message); err != nil { - log.Printf("Could not publish incall message in room %s: %s", r.Id(), err) + r.logger.Printf("Could not publish incall message in room %s: %s", r.Id(), err) } } @@ -849,7 +851,7 @@ func (r *Room) PublishUsersInCallChangedAll(inCall int) { return } - log.Printf("Sessions %v joined call %s", joined, r.id) + r.logger.Printf("Sessions %v joined call %s", joined, r.id) } else if len(r.inCallSessions) > 0 { // Perform actual leaving asynchronously. ch := make(chan *ClientSession, 1) @@ -902,7 +904,7 @@ func (r *Room) PublishUsersInCallChangedAll(inCall int) { for _, session := range notify { if !session.SendMessage(message) { - log.Printf("Could not send incall message from room %s to %s", r.Id(), session.PublicId()) + r.logger.Printf("Could not send incall message from room %s to %s", r.Id(), session.PublicId()) } } } @@ -924,7 +926,7 @@ func (r *Room) PublishUsersChanged(changed []api.StringMap, users []api.StringMa }, } if err := r.publish(message); err != nil { - log.Printf("Could not publish users changed message in room %s: %s", r.Id(), err) + r.logger.Printf("Could not publish users changed message in room %s: %s", r.Id(), err) } } @@ -978,7 +980,7 @@ func (r *Room) NotifySessionChanged(session Session, flags SessionChangeFlag) { r.mu.Lock() if !r.inCallSessions[session] { r.inCallSessions[session] = true - log.Printf("Session %s joined call %s", session.PublicId(), r.id) + r.logger.Printf("Session %s joined call %s", session.PublicId(), r.id) } r.mu.Unlock() case 2: @@ -1003,7 +1005,7 @@ func (r *Room) publishUsersChangedWithInternal() { } if err := r.publish(message); err != nil { - log.Printf("Could not publish users changed message in room %s: %s", r.Id(), err) + r.logger.Printf("Could not publish users changed message in room %s: %s", r.Id(), err) } } @@ -1021,7 +1023,7 @@ func (r *Room) publishSessionFlagsChanged(session *VirtualSession) { }, } if err := r.publish(message); err != nil { - log.Printf("Could not publish flags changed message in room %s: %s", r.Id(), err) + r.logger.Printf("Could not publish flags changed message in room %s: %s", r.Id(), err) } } @@ -1040,7 +1042,7 @@ func (r *Room) publishActiveSessions() (int, *sync.WaitGroup) { u += PathToOcsSignalingBackend parsed, err := url.Parse(u) if err != nil { - log.Printf("Could not parse backend url %s: %s", u, err) + r.logger.Printf("Could not parse backend url %s: %s", u, err) continue } @@ -1087,16 +1089,17 @@ func (r *Room) publishActiveSessions() (int, *sync.WaitGroup) { return 0, &wg } var count int + ctx := NewLoggerContext(context.Background(), r.logger) for u, e := range entries { wg.Add(1) count += len(e) go func(url *url.URL, entries []BackendPingEntry) { defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), r.hub.backendTimeout) + sendCtx, cancel := context.WithTimeout(ctx, r.hub.backendTimeout) defer cancel() - if err := r.hub.roomPing.SendPings(ctx, r.id, url, entries); err != nil { - log.Printf("Error pinging room %s for active entries %+v: %s", r.id, entries, err) + if err := r.hub.roomPing.SendPings(sendCtx, r.id, url, entries); err != nil { + r.logger.Printf("Error pinging room %s for active entries %+v: %s", r.id, entries, err) } }(urls[u], e) } @@ -1120,7 +1123,7 @@ func (r *Room) publishRoomMessage(message *BackendRoomMessageRequest) { }, } if err := r.publish(msg); err != nil { - log.Printf("Could not publish room message in room %s: %s", r.Id(), err) + r.logger.Printf("Could not publish room message in room %s: %s", r.Id(), err) } } @@ -1147,7 +1150,7 @@ func (r *Room) publishSwitchTo(message *BackendRoomSwitchToMessageRequest) { Type: "message", Message: msg, }); err != nil { - log.Printf("Error publishing switchto event to session %s: %s", sessionId, err) + r.logger.Printf("Error publishing switchto event to session %s: %s", sessionId, err) } }(sessionId) } @@ -1175,7 +1178,7 @@ func (r *Room) publishSwitchTo(message *BackendRoomSwitchToMessageRequest) { Type: "message", Message: msg, }); err != nil { - log.Printf("Error publishing switchto event to session %s: %s", sessionId, err) + r.logger.Printf("Error publishing switchto event to session %s: %s", sessionId, err) } }(sessionId, details) } diff --git a/room_ping.go b/room_ping.go index 3e9d708..7958370 100644 --- a/room_ping.go +++ b/room_ping.go @@ -23,7 +23,6 @@ package signaling import ( "context" - "log" "net/url" "slices" "sync" @@ -100,7 +99,7 @@ loop: case <-p.closer.C: break loop case <-ticker.C: - p.publishActiveSessions() + p.publishActiveSessions(context.Background()) } } } @@ -114,20 +113,21 @@ func (p *RoomPing) getAndClearEntries() map[string]*pingEntries { return entries } -func (p *RoomPing) publishEntries(entries *pingEntries, timeout time.Duration) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) +func (p *RoomPing) publishEntries(ctx context.Context, entries *pingEntries, timeout time.Duration) { + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() limit, _, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit) if !found || limit <= 0 { // Limit disabled while waiting for the next iteration, fallback to sending // one request per room. + logger := LoggerFromContext(ctx) for roomId, e := range entries.entries { - ctx2, cancel2 := context.WithTimeout(context.Background(), timeout) + ctx2, cancel2 := context.WithTimeout(context.WithoutCancel(ctx), timeout) defer cancel2() if err := p.sendPingsDirect(ctx2, roomId, entries.url, e); err != nil { - log.Printf("Error pinging room %s for active entries %+v: %s", roomId, e, err) + logger.Printf("Error pinging room %s for active entries %+v: %s", roomId, e, err) } } return @@ -137,10 +137,10 @@ func (p *RoomPing) publishEntries(entries *pingEntries, timeout time.Duration) { for _, e := range entries.entries { allEntries = append(allEntries, e...) } - p.sendPingsCombined(entries.url, allEntries, limit, timeout) + p.sendPingsCombined(ctx, entries.url, allEntries, limit, timeout) } -func (p *RoomPing) publishActiveSessions() { +func (p *RoomPing) publishActiveSessions(ctx context.Context) { var timeout time.Duration if p.backend.hub != nil { timeout = p.backend.hub.backendTimeout @@ -154,7 +154,7 @@ func (p *RoomPing) publishActiveSessions() { for _, e := range entries { go func(e *pingEntries) { defer wg.Done() - p.publishEntries(e, timeout) + p.publishEntries(ctx, e, timeout) }(e) } wg.Wait() @@ -166,15 +166,16 @@ func (p *RoomPing) sendPingsDirect(ctx context.Context, roomId string, url *url. return p.backend.PerformJSONRequest(ctx, url, request, &response) } -func (p *RoomPing) sendPingsCombined(url *url.URL, entries []BackendPingEntry, limit int, timeout time.Duration) { +func (p *RoomPing) sendPingsCombined(ctx context.Context, url *url.URL, entries []BackendPingEntry, limit int, timeout time.Duration) { + logger := LoggerFromContext(ctx) for tosend := range slices.Chunk(entries, limit) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) + subCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() request := NewBackendClientPingRequest("", tosend) var response BackendClientResponse - if err := p.backend.PerformJSONRequest(ctx, url, request, &response); err != nil { - log.Printf("Error sending combined ping session entries %+v to %s: %s", tosend, url, err) + if err := p.backend.PerformJSONRequest(subCtx, url, request, &response); err != nil { + logger.Printf("Error sending combined ping session entries %+v to %s: %s", tosend, url, err) } } } diff --git a/room_ping_test.go b/room_ping_test.go index 58ce74d..a89f536 100644 --- a/room_ping_test.go +++ b/room_ping_test.go @@ -32,7 +32,7 @@ import ( "github.com/stretchr/testify/require" ) -func NewRoomPingForTest(t *testing.T) (*url.URL, *RoomPing) { +func NewRoomPingForTest(ctx context.Context, t *testing.T) (*url.URL, *RoomPing) { require := require.New(t) r := mux.NewRouter() registerBackendHandler(t, r) @@ -45,7 +45,7 @@ func NewRoomPingForTest(t *testing.T) (*url.URL, *RoomPing) { config, err := getTestConfig(server) require.NoError(err) - backend, err := NewBackendClient(config, 1, "0.0", nil) + backend, err := NewBackendClient(ctx, config, 1, "0.0", nil) require.NoError(err) p, err := NewRoomPing(backend, backend.capabilities) @@ -58,11 +58,12 @@ func NewRoomPingForTest(t *testing.T) (*url.URL, *RoomPing) { } func TestSingleRoomPing(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) - u, ping := NewRoomPingForTest(t) + u, ping := NewRoomPingForTest(ctx, t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() room1 := &Room{ @@ -95,16 +96,17 @@ func TestSingleRoomPing(t *testing.T) { } clearPingRequests(t) - ping.publishActiveSessions() + ping.publishActiveSessions(ctx) assert.Empty(getPingRequests(t)) } func TestMultiRoomPing(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) - u, ping := NewRoomPingForTest(t) + u, ping := NewRoomPingForTest(ctx, t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() room1 := &Room{ @@ -131,18 +133,19 @@ func TestMultiRoomPing(t *testing.T) { assert.NoError(ping.SendPings(ctx, room2.Id(), u, entries2)) assert.Empty(getPingRequests(t)) - ping.publishActiveSessions() + ping.publishActiveSessions(ctx) if requests := getPingRequests(t); assert.Len(requests, 1) { assert.Len(requests[0].Ping.Entries, 2) } } func TestMultiRoomPing_Separate(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) - u, ping := NewRoomPingForTest(t) + u, ping := NewRoomPingForTest(ctx, t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() room1 := &Room{ @@ -165,18 +168,19 @@ func TestMultiRoomPing_Separate(t *testing.T) { assert.NoError(ping.SendPings(ctx, room1.Id(), u, entries2)) assert.Empty(getPingRequests(t)) - ping.publishActiveSessions() + ping.publishActiveSessions(ctx) if requests := getPingRequests(t); assert.Len(requests, 1) { assert.Len(requests[0].Ping.Entries, 2) } } func TestMultiRoomPing_DeleteRoom(t *testing.T) { - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) assert := assert.New(t) - u, ping := NewRoomPingForTest(t) + u, ping := NewRoomPingForTest(ctx, t) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() room1 := &Room{ @@ -205,7 +209,7 @@ func TestMultiRoomPing_DeleteRoom(t *testing.T) { ping.DeleteRoom(room2.Id()) - ping.publishActiveSessions() + ping.publishActiveSessions(ctx) if requests := getPingRequests(t); assert.Len(requests, 1) { assert.Len(requests[0].Ping.Entries, 1) } diff --git a/room_test.go b/room_test.go index eadcd19..8901d0b 100644 --- a/room_test.go +++ b/room_test.go @@ -76,18 +76,19 @@ func TestRoom_InCall(t *testing.T) { func TestRoom_Update(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) hub, _, router, server := CreateHubForTest(t) config, err := getTestConfig(server) require.NoError(err) - b, err := NewBackendServer(config, hub, "no-version") + b, err := NewBackendServer(ctx, config, hub, "no-version") require.NoError(err) require.NoError(b.Start(router)) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client, hello := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId) @@ -170,18 +171,19 @@ loop: func TestRoom_Delete(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) hub, _, router, server := CreateHubForTest(t) config, err := getTestConfig(server) require.NoError(err) - b, err := NewBackendServer(config, hub, "no-version") + b, err := NewBackendServer(ctx, config, hub, "no-version") require.NoError(err) require.NoError(b.Start(router)) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client, hello := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId) @@ -267,14 +269,15 @@ loop: func TestRoom_RoomJoinFeatures(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) hub, _, router, server := CreateHubForTest(t) config, err := getTestConfig(server) require.NoError(err) - b, err := NewBackendServer(config, hub, "no-version") + b, err := NewBackendServer(ctx, config, hub, "no-version") require.NoError(err) require.NoError(b.Start(router)) @@ -284,7 +287,7 @@ func TestRoom_RoomJoinFeatures(t *testing.T) { features := []string{"one", "two", "three"} require.NoError(client.SendHelloClientWithFeatures(testDefaultUserId, features)) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() hello := MustSucceed1(t, client.RunUntilHello, ctx) @@ -304,18 +307,19 @@ func TestRoom_RoomJoinFeatures(t *testing.T) { func TestRoom_RoomSessionData(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) hub, _, router, server := CreateHubForTest(t) config, err := getTestConfig(server) require.NoError(err) - b, err := NewBackendServer(config, hub, "no-version") + b, err := NewBackendServer(ctx, config, hub, "no-version") require.NoError(err) require.NoError(b.Start(router)) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client, hello := NewTestClientWithHello(ctx, t, server, hub, authAnonymousUserId) @@ -347,18 +351,19 @@ func TestRoom_RoomSessionData(t *testing.T) { func TestRoom_InCallAll(t *testing.T) { t.Parallel() - CatchLogForTest(t) + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) require := require.New(t) assert := assert.New(t) hub, _, router, server := CreateHubForTest(t) config, err := getTestConfig(server) require.NoError(err) - b, err := NewBackendServer(config, hub, "no-version") + b, err := NewBackendServer(ctx, config, hub, "no-version") require.NoError(err) require.NoError(b.Start(router)) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(ctx, testTimeout) defer cancel() client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") diff --git a/roomsessions_builtin.go b/roomsessions_builtin.go index 02a53c4..9c54bf5 100644 --- a/roomsessions_builtin.go +++ b/roomsessions_builtin.go @@ -24,7 +24,6 @@ package signaling import ( "context" "errors" - "log" "sync" "sync/atomic" ) @@ -117,6 +116,7 @@ func (r *BuiltinRoomSessions) LookupSessionId(ctx context.Context, roomSessionId var wg sync.WaitGroup var result atomic.Value + logger := LoggerFromContext(ctx) for _, client := range clients { wg.Add(1) go func(client *GrpcClient) { @@ -126,10 +126,10 @@ func (r *BuiltinRoomSessions) LookupSessionId(ctx context.Context, roomSessionId if errors.Is(err, context.Canceled) { return } else if err != nil { - log.Printf("Received error while checking for room session id %s on %s: %s", roomSessionId, client.Target(), err) + logger.Printf("Received error while checking for room session id %s on %s: %s", roomSessionId, client.Target(), err) return } else if sid == "" { - log.Printf("Received empty session id for room session id %s from %s", roomSessionId, client.Target()) + logger.Printf("Received empty session id for room session id %s from %s", roomSessionId, client.Target()) return } diff --git a/server/main.go b/server/main.go index 879accb..444d2ca 100644 --- a/server/main.go +++ b/server/main.go @@ -93,7 +93,8 @@ func createTLSListener(addr string, certFile, keyFile string) (net.Listener, err } type Listeners struct { - mu sync.Mutex + logger signaling.Logger // +checklocksignore + mu sync.Mutex // +checklocks:mu listeners []net.Listener } @@ -111,7 +112,7 @@ func (l *Listeners) Close() { for _, listener := range l.listeners { if err := listener.Close(); err != nil { - log.Printf("Error closing listener %s: %s", listener.Addr(), err) + l.logger.Printf("Error closing listener %s: %s", listener.Addr(), err) } } } @@ -126,46 +127,51 @@ func main() { } sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt) signal.Notify(sigChan, syscall.SIGHUP) signal.Notify(sigChan, syscall.SIGUSR1) + stopCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + defer stop() + + logger := log.Default() + stopCtx = signaling.NewLoggerContext(stopCtx, logger) + if *cpuprofile != "" { f, err := os.Create(*cpuprofile) if err != nil { - log.Fatal(err) + logger.Fatal(err) } if err := runtimepprof.StartCPUProfile(f); err != nil { - log.Fatalf("Error writing CPU profile to %s: %s", *cpuprofile, err) + logger.Fatalf("Error writing CPU profile to %s: %s", *cpuprofile, err) } - log.Printf("Writing CPU profile to %s ...", *cpuprofile) + logger.Printf("Writing CPU profile to %s ...", *cpuprofile) defer runtimepprof.StopCPUProfile() } if *memprofile != "" { f, err := os.Create(*memprofile) if err != nil { - log.Fatal(err) + logger.Fatal(err) } defer func() { - log.Printf("Writing Memory profile to %s ...", *memprofile) + logger.Printf("Writing Memory profile to %s ...", *memprofile) runtime.GC() if err := runtimepprof.WriteHeapProfile(f); err != nil { - log.Printf("Error writing Memory profile to %s: %s", *memprofile, err) + logger.Printf("Error writing Memory profile to %s: %s", *memprofile, err) } }() } - log.Printf("Starting up version %s/%s as pid %d", version, runtime.Version(), os.Getpid()) + logger.Printf("Starting up version %s/%s as pid %d", version, runtime.Version(), os.Getpid()) config, err := goconf.ReadConfigFile(*configFlag) if err != nil { - log.Fatal("Could not read configuration: ", err) + logger.Fatal("Could not read configuration: ", err) } - log.Printf("Using a maximum of %d CPUs", runtime.GOMAXPROCS(0)) + logger.Printf("Using a maximum of %d CPUs", runtime.GOMAXPROCS(0)) signaling.RegisterStats() @@ -174,61 +180,61 @@ func main() { natsUrl = nats.DefaultURL } - events, err := signaling.NewAsyncEvents(natsUrl) + events, err := signaling.NewAsyncEvents(stopCtx, natsUrl) if err != nil { - log.Fatal("Could not create async events client: ", err) + logger.Fatal("Could not create async events client: ", err) } defer events.Close() - dnsMonitor, err := signaling.NewDnsMonitor(dnsMonitorInterval) + dnsMonitor, err := signaling.NewDnsMonitor(logger, dnsMonitorInterval) if err != nil { - log.Fatal("Could not create DNS monitor: ", err) + logger.Fatal("Could not create DNS monitor: ", err) } if err := dnsMonitor.Start(); err != nil { - log.Fatal("Could not start DNS monitor: ", err) + logger.Fatal("Could not start DNS monitor: ", err) } defer dnsMonitor.Stop() - etcdClient, err := signaling.NewEtcdClient(config, "mcu") + etcdClient, err := signaling.NewEtcdClient(logger, config, "mcu") if err != nil { - log.Fatalf("Could not create etcd client: %s", err) + logger.Fatalf("Could not create etcd client: %s", err) } defer func() { if err := etcdClient.Close(); err != nil { - log.Printf("Error while closing etcd client: %s", err) + logger.Printf("Error while closing etcd client: %s", err) } }() - rpcServer, err := signaling.NewGrpcServer(config, version) + rpcServer, err := signaling.NewGrpcServer(stopCtx, config, version) if err != nil { - log.Fatalf("Could not create RPC server: %s", err) + logger.Fatalf("Could not create RPC server: %s", err) } go func() { if err := rpcServer.Run(); err != nil { - log.Fatalf("Could not start RPC server: %s", err) + logger.Fatalf("Could not start RPC server: %s", err) } }() defer rpcServer.Close() - rpcClients, err := signaling.NewGrpcClients(config, etcdClient, dnsMonitor, version) + rpcClients, err := signaling.NewGrpcClients(stopCtx, config, etcdClient, dnsMonitor, version) if err != nil { - log.Fatalf("Could not create RPC clients: %s", err) + logger.Fatalf("Could not create RPC clients: %s", err) } defer rpcClients.Close() r := mux.NewRouter() - hub, err := signaling.NewHub(config, events, rpcServer, rpcClients, etcdClient, r, version) + hub, err := signaling.NewHub(stopCtx, config, events, rpcServer, rpcClients, etcdClient, r, version) if err != nil { - log.Fatal("Could not create hub: ", err) + logger.Fatal("Could not create hub: ", err) } mcuUrl, _ := signaling.GetStringOptionWithEnv(config, "mcu", "url") mcuType, _ := config.GetString("mcu", "type") if mcuType == "" && mcuUrl != "" { - log.Printf("WARNING: Old-style MCU configuration detected with url but no type, defaulting to type %s", signaling.McuTypeJanus) + logger.Printf("WARNING: Old-style MCU configuration detected with url but no type, defaulting to type %s", signaling.McuTypeJanus) mcuType = signaling.McuTypeJanus } else if mcuType == signaling.McuTypeJanus && mcuUrl == "" { - log.Printf("WARNING: Old-style MCU configuration detected with type but no url, disabling") + logger.Printf("WARNING: Old-style MCU configuration detected with type but no url, disabling") mcuType = "" } @@ -246,41 +252,41 @@ func main() { signaling.UnregisterProxyMcuStats() signaling.RegisterJanusMcuStats() case signaling.McuTypeProxy: - mcu, err = signaling.NewMcuProxy(config, etcdClient, rpcClients, dnsMonitor) + mcu, err = signaling.NewMcuProxy(ctx, config, etcdClient, rpcClients, dnsMonitor) signaling.UnregisterJanusMcuStats() signaling.RegisterProxyMcuStats() default: - log.Fatal("Unsupported MCU type: ", mcuType) + logger.Fatal("Unsupported MCU type: ", mcuType) } if err == nil { err = mcu.Start(ctx) if err != nil { - log.Printf("Could not create %s MCU: %s", mcuType, err) + logger.Printf("Could not create %s MCU: %s", mcuType, err) } } if err == nil { break } - log.Printf("Could not initialize %s MCU (%s) will retry in %s", mcuType, err, mcuRetry) + logger.Printf("Could not initialize %s MCU (%s) will retry in %s", mcuType, err, mcuRetry) mcuRetryTimer.Reset(mcuRetry) select { + case <-stopCtx.Done(): + logger.Fatalf("Cancelled") case sig := <-sigChan: switch sig { - case os.Interrupt: - log.Fatalf("Cancelled") case syscall.SIGHUP: - log.Printf("Received SIGHUP, reloading %s", *configFlag) + logger.Printf("Received SIGHUP, reloading %s", *configFlag) if config, err = goconf.ReadConfigFile(*configFlag); err != nil { - log.Printf("Could not read configuration from %s: %s", *configFlag, err) + logger.Printf("Could not read configuration from %s: %s", *configFlag, err) } else { mcuUrl, _ = signaling.GetStringOptionWithEnv(config, "mcu", "url") mcuType, _ = config.GetString("mcu", "type") if mcuType == "" && mcuUrl != "" { - log.Printf("WARNING: Old-style MCU configuration detected with url but no type, defaulting to type %s", signaling.McuTypeJanus) + logger.Printf("WARNING: Old-style MCU configuration detected with url but no type, defaulting to type %s", signaling.McuTypeJanus) mcuType = signaling.McuTypeJanus } else if mcuType == signaling.McuTypeJanus && mcuUrl == "" { - log.Printf("WARNING: Old-style MCU configuration detected with type but no url, disabling") + logger.Printf("WARNING: Old-style MCU configuration detected with type but no url, disabling") mcuType = "" break mcuTypeLoop } @@ -294,7 +300,7 @@ func main() { if mcu != nil { defer mcu.Stop() - log.Printf("Using %s MCU", mcuType) + logger.Printf("Using %s MCU", mcuType) hub.SetMcu(mcu) } } @@ -302,21 +308,23 @@ func main() { go hub.Run() defer hub.Stop() - server, err := signaling.NewBackendServer(config, hub, version) + server, err := signaling.NewBackendServer(stopCtx, config, hub, version) if err != nil { - log.Fatal("Could not create backend server: ", err) + logger.Fatal("Could not create backend server: ", err) } if err := server.Start(r); err != nil { - log.Fatal("Could not start backend server: ", err) + logger.Fatal("Could not start backend server: ", err) } - var listeners Listeners + listeners := Listeners{ + logger: logger, + } if saddr, _ := signaling.GetStringOptionWithEnv(config, "https", "listen"); saddr != "" { cert, _ := config.GetString("https", "certificate") key, _ := config.GetString("https", "key") if cert == "" || key == "" { - log.Fatal("Need a certificate and key for the HTTPS listener") + logger.Fatal("Need a certificate and key for the HTTPS listener") } readTimeout, _ := config.GetInt("https", "readtimeout") @@ -329,10 +337,10 @@ func main() { } for address := range signaling.SplitEntries(saddr, " ") { go func(address string) { - log.Println("Listening on", address) + logger.Println("Listening on", address) listener, err := createTLSListener(address, cert, key) if err != nil { - log.Fatal("Could not start listening: ", err) + logger.Fatal("Could not start listening: ", err) } srv := &http.Server{ Handler: r, @@ -343,7 +351,7 @@ func main() { listeners.Add(listener) if err := srv.Serve(listener); err != nil { if !hub.IsShutdownScheduled() || !errors.Is(err, net.ErrClosed) { - log.Fatal("Could not start server: ", err) + logger.Fatal("Could not start server: ", err) } } }(address) @@ -362,10 +370,10 @@ func main() { for address := range signaling.SplitEntries(addr, " ") { go func(address string) { - log.Println("Listening on", address) + logger.Println("Listening on", address) listener, err := createListener(address) if err != nil { - log.Fatal("Could not start listening: ", err) + logger.Fatal("Could not start listening: ", err) } srv := &http.Server{ Handler: r, @@ -377,7 +385,7 @@ func main() { listeners.Add(listener) if err := srv.Serve(listener); err != nil { if !hub.IsShutdownScheduled() || !errors.Is(err, net.ErrClosed) { - log.Fatal("Could not start server: ", err) + logger.Fatal("Could not start server: ", err) } } }(address) @@ -387,26 +395,26 @@ func main() { loop: for { select { + case <-stopCtx.Done(): + logger.Println("Interrupted") + break loop case sig := <-sigChan: switch sig { - case os.Interrupt: - log.Println("Interrupted") - break loop case syscall.SIGHUP: - log.Printf("Received SIGHUP, reloading %s", *configFlag) + logger.Printf("Received SIGHUP, reloading %s", *configFlag) if config, err := goconf.ReadConfigFile(*configFlag); err != nil { - log.Printf("Could not read configuration from %s: %s", *configFlag, err) + logger.Printf("Could not read configuration from %s: %s", *configFlag, err) } else { - hub.Reload(config) + hub.Reload(stopCtx, config) server.Reload(config) } case syscall.SIGUSR1: - log.Printf("Received SIGUSR1, scheduling server to shutdown") + logger.Printf("Received SIGUSR1, scheduling server to shutdown") hub.ScheduleShutdown() listeners.Close() } case <-hub.ShutdownChannel(): - log.Printf("All clients disconnected, shutting down") + logger.Printf("All clients disconnected, shutting down") break loop } } diff --git a/test_helpers.go b/test_helpers.go index 9c1a7f7..a0ca4ec 100644 --- a/test_helpers.go +++ b/test_helpers.go @@ -24,22 +24,11 @@ package signaling import ( "bytes" "fmt" - "io" "log" "sync" "testing" ) -var ( - prevWriter io.Writer - prevFlags int -) - -func init() { - prevWriter = log.Writer() - prevFlags = log.Flags() -} - type testLogWriter struct { mu sync.Mutex t testing.TB @@ -55,18 +44,6 @@ func (w *testLogWriter) Write(b []byte) (int, error) { return writeTestOutput(w.t, b) } -func CatchLogForTest(t testing.TB) { - t.Cleanup(func() { - log.SetOutput(prevWriter) - log.SetFlags(prevFlags) - }) - - log.SetOutput(&testLogWriter{ - t: t, - }) - log.SetFlags(prevFlags | log.Lmicroseconds | log.Lshortfile) -} - var ( // +checklocks:testLoggersLock testLoggers = map[testing.TB]Logger{} diff --git a/throttle.go b/throttle.go index 6e78582..a19aa25 100644 --- a/throttle.go +++ b/throttle.go @@ -24,7 +24,6 @@ package signaling import ( "context" "errors" - "log" "net" "strconv" "sync" @@ -277,7 +276,8 @@ func (t *memoryThrottler) CheckBruteforce(ctx context.Context, client string, ac if l >= maxBruteforceAttempts { delta := now.Sub(entries[l-maxBruteforceAttempts].ts) if delta <= maxBruteforceDurationThreshold { - log.Printf("Detected bruteforce attempt on \"%s\" from %s", action, client) + logger := LoggerFromContext(ctx) + logger.Printf("Detected bruteforce attempt on \"%s\" from %s", action, client) statsThrottleBruteforceTotal.WithLabelValues(action).Inc() return doThrottle, ErrBruteforceDetected } @@ -301,7 +301,8 @@ func (t *memoryThrottler) throttle(ctx context.Context, client string, action st } count := t.addEntry(client, action, entry) delay := t.getDelay(count - 1) - log.Printf("Failed attempt on \"%s\" from %s, throttling by %s", action, client, delay) + logger := LoggerFromContext(ctx) + logger.Printf("Failed attempt on \"%s\" from %s, throttling by %s", action, client, delay) statsThrottleDelayedTotal.WithLabelValues(action, strconv.FormatInt(delay.Milliseconds(), 10)).Inc() t.doDelay(ctx, delay) } diff --git a/throttle_test.go b/throttle_test.go index 62ebf7f..07f3520 100644 --- a/throttle_test.go +++ b/throttle_test.go @@ -71,7 +71,8 @@ func TestThrottler(t *testing.T) { assert := assert.New(t) th := newMemoryThrottlerForTest(t) - ctx := context.Background() + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") assert.NoError(err) @@ -104,7 +105,8 @@ func TestThrottlerIPv6(t *testing.T) { assert := assert.New(t) th := newMemoryThrottlerForTest(t) - ctx := context.Background() + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) // Make sure full /64 subnets are throttled for IPv6. throttle1, err := th.CheckBruteforce(ctx, "2001:db8:abcd:0012::1", "action1") @@ -140,7 +142,8 @@ func TestThrottler_Bruteforce(t *testing.T) { assert := assert.New(t) th := newMemoryThrottlerForTest(t) - ctx := context.Background() + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) delay := 100 * time.Millisecond for range maxBruteforceAttempts { @@ -167,7 +170,8 @@ func TestThrottler_Cleanup(t *testing.T) { th, ok := throttler.(*memoryThrottler) require.True(t, ok, "required memoryThrottler, got %T", throttler) - ctx := context.Background() + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") assert.NoError(err) @@ -220,7 +224,8 @@ func TestThrottler_ExpirePartial(t *testing.T) { assert := assert.New(t) th := newMemoryThrottlerForTest(t) - ctx := context.Background() + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") assert.NoError(err) @@ -251,7 +256,8 @@ func TestThrottler_ExpireAll(t *testing.T) { assert := assert.New(t) th := newMemoryThrottlerForTest(t) - ctx := context.Background() + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") assert.NoError(err) @@ -282,7 +288,8 @@ func TestThrottler_Negative(t *testing.T) { assert := assert.New(t) th := newMemoryThrottlerForTest(t) - ctx := context.Background() + logger := NewLoggerForTest(t) + ctx := NewLoggerContext(t.Context(), logger) delay := 100 * time.Millisecond for range maxBruteforceAttempts * 10 { diff --git a/transient_data_test.go b/transient_data_test.go index ca554fc..2ee8331 100644 --- a/transient_data_test.go +++ b/transient_data_test.go @@ -139,7 +139,6 @@ func Test_TransientDataDeadlock(t *testing.T) { func Test_TransientMessages(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) hub, _, _, server := CreateHubForTest(t) diff --git a/virtualsession.go b/virtualsession.go index 4ec31cf..57ef6aa 100644 --- a/virtualsession.go +++ b/virtualsession.go @@ -24,7 +24,6 @@ package signaling import ( "context" "encoding/json" - "log" "net/url" "sync/atomic" @@ -38,6 +37,7 @@ const ( ) type VirtualSession struct { + logger Logger hub *Hub session *ClientSession privateId PrivateSessionId @@ -61,6 +61,7 @@ func GetVirtualSessionId(session Session, sessionId PublicSessionId) PublicSessi func NewVirtualSession(session *ClientSession, privateId PrivateSessionId, publicId PublicSessionId, data *SessionIdData, msg *AddSessionInternalClientMessage) (*VirtualSession, error) { result := &VirtualSession{ + logger: session.hub.logger, hub: session.hub, session: session, privateId: privateId, @@ -154,7 +155,7 @@ func (s *VirtualSession) SetRoom(room *Room) { s.room.Store(room) if room != nil { if err := s.hub.roomSessions.SetRoomSession(s, RoomSessionId(s.PublicId())); err != nil { - log.Printf("Error adding virtual room session %s: %s", s.PublicId(), err) + s.logger.Printf("Error adding virtual room session %s: %s", s.PublicId(), err) } } else { s.hub.roomSessions.DeleteRoomSession(s) @@ -191,7 +192,8 @@ func (s *VirtualSession) CloseWithFeedback(session Session, message *ClientMessa } func (s *VirtualSession) notifyBackendRemoved(room *Room, session Session, message *ClientMessage) { - ctx, cancel := context.WithTimeout(context.Background(), s.hub.backendTimeout) + ctx := NewLoggerContext(context.Background(), s.logger) + ctx, cancel := context.WithTimeout(ctx, s.hub.backendTimeout) defer cancel() if options := s.Options(); options != nil && options.ActorId != "" && options.ActorType != "" { @@ -203,7 +205,7 @@ func (s *VirtualSession) notifyBackendRemoved(room *Room, session Session, messa var response BackendClientResponse if err := s.hub.backend.PerformJSONRequest(ctx, s.ParsedBackendOcsUrl(), request, &response); err != nil { virtualSessionId := GetVirtualSessionId(s.session, s.PublicId()) - log.Printf("Could not leave virtual session %s at backend %s: %s", virtualSessionId, s.BackendUrl(), err) + s.logger.Printf("Could not leave virtual session %s at backend %s: %s", virtualSessionId, s.BackendUrl(), err) if session != nil && message != nil { reply := message.NewErrorServerMessage(NewError("remove_failed", "Could not remove virtual session from backend.")) session.SendMessage(reply) @@ -214,7 +216,7 @@ func (s *VirtualSession) notifyBackendRemoved(room *Room, session Session, messa if response.Type == "error" { virtualSessionId := GetVirtualSessionId(s.session, s.PublicId()) if session != nil && message != nil && (response.Error == nil || response.Error.Code != "no_such_room") { - log.Printf("Could not leave virtual session %s at backend %s: %+v", virtualSessionId, s.BackendUrl(), response.Error) + s.logger.Printf("Could not leave virtual session %s at backend %s: %+v", virtualSessionId, s.BackendUrl(), response.Error) reply := message.NewErrorServerMessage(NewError("remove_failed", response.Error.Error())) session.SendMessage(reply) } @@ -228,7 +230,7 @@ func (s *VirtualSession) notifyBackendRemoved(room *Room, session Session, messa var response BackendClientSessionResponse err := s.hub.backend.PerformJSONRequest(ctx, s.ParsedBackendOcsUrl(), request, &response) if err != nil { - log.Printf("Could not remove virtual session %s from backend %s: %s", s.PublicId(), s.BackendUrl(), err) + s.logger.Printf("Could not remove virtual session %s from backend %s: %s", s.PublicId(), s.BackendUrl(), err) if session != nil && message != nil { reply := message.NewErrorServerMessage(NewError("remove_failed", "Could not remove virtual session from backend.")) session.SendMessage(reply) @@ -291,7 +293,7 @@ func (s *VirtualSession) ProcessAsyncSessionMessage(message *AsyncMessage) { message.Message.Event.Type == "disinvite" && message.Message.Event.Disinvite != nil && message.Message.Event.Disinvite.RoomId == room.Id() { - log.Printf("Virtual session %s was disinvited from room %s, hanging up", s.PublicId(), room.Id()) + s.logger.Printf("Virtual session %s was disinvited from room %s, hanging up", s.PublicId(), room.Id()) payload := api.StringMap{ "type": "hangup", "hangup": map[string]string{ @@ -300,7 +302,7 @@ func (s *VirtualSession) ProcessAsyncSessionMessage(message *AsyncMessage) { } data, err := json.Marshal(payload) if err != nil { - log.Printf("could not marshal control payload %+v: %s", payload, err) + s.logger.Printf("could not marshal control payload %+v: %s", payload, err) return } diff --git a/virtualsession_test.go b/virtualsession_test.go index 3e7ddf1..cf01ae4 100644 --- a/virtualsession_test.go +++ b/virtualsession_test.go @@ -35,7 +35,6 @@ import ( func TestVirtualSession(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -219,7 +218,6 @@ func TestVirtualSession(t *testing.T) { func TestVirtualSessionActorInformation(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -429,7 +427,6 @@ func checkHasEntryWithInCall(t *testing.T, message *RoomEventServerMessage, sess func TestVirtualSessionCustomInCall(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t) @@ -571,7 +568,6 @@ func TestVirtualSessionCustomInCall(t *testing.T) { func TestVirtualSessionCleanup(t *testing.T) { t.Parallel() - CatchLogForTest(t) require := require.New(t) assert := assert.New(t) hub, _, _, server := CreateHubForTest(t)