diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index fbce48d..25ab73b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -8,6 +8,7 @@ on: - '.golangci.yml' - '**.go' - 'go.*' + - 'Makefile' pull_request: branches: [ master ] paths: @@ -15,6 +16,7 @@ on: - '.golangci.yml' - '**.go' - 'go.*' + - 'Makefile' permissions: contents: read @@ -53,6 +55,20 @@ jobs: run: | GOEXPERIMENT=synctest go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -test ./... + checklocks: + name: checklocks + runs-on: ubuntu-latest + continue-on-error: true + steps: + - uses: actions/checkout@v5 + - uses: actions/setup-go@v6 + with: + go-version: "1.24" + + - name: checklocks + run: | + make checklocks + dependencies: name: dependencies runs-on: ubuntu-latest diff --git a/Makefile b/Makefile index b5843dd..de8898e 100644 --- a/Makefile +++ b/Makefile @@ -85,6 +85,9 @@ $(GOPATHBIN)/protoc-gen-go: go.mod go.sum $(GOPATHBIN)/protoc-gen-go-grpc: go.mod go.sum $(GO) install google.golang.org/grpc/cmd/protoc-gen-go-grpc +$(GOPATHBIN)/checklocks: go.mod go.sum + $(GO) install gvisor.dev/gvisor/tools/checklocks/cmd/checklocks@go + continentmap.go: $(CURDIR)/scripts/get_continent_map.py $@ @@ -108,6 +111,9 @@ vet: test: vet GOEXPERIMENT=synctest $(GO) test -timeout $(TIMEOUT) $(TESTARGS) ./... +checklocks: $(GOPATHBIN)/checklocks + GOEXPERIMENT=synctest go vet -vettool=$(GOPATHBIN)/checklocks ./... + cover: vet rm -f cover.out && \ GOEXPERIMENT=synctest $(GO) test -timeout $(TIMEOUT) -coverprofile cover.out ./... diff --git a/api_signaling.go b/api_signaling.go index 960a987..af5c86b 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -52,7 +52,7 @@ const ( ) var ( - ErrNoSdp = NewError("no_sdp", "Payload does not contain a SDP.") + ErrNoSdp = NewError("no_sdp", "Payload does not contain a SDP.") // +checklocksignore: Global readonly variable. 
ErrInvalidSdp = NewError("invalid_sdp", "Payload does not contain a valid SDP.") ErrNoCandidate = NewError("no_candidate", "Payload does not contain a candidate.") diff --git a/async_events.go b/async_events.go index 308b3cf..d8bb0c9 100644 --- a/async_events.go +++ b/async_events.go @@ -72,6 +72,7 @@ func NewAsyncEvents(url string) (AsyncEvents, error) { type asyncBackendRoomSubscriber struct { mu sync.Mutex + // +checklocks:mu listeners map[AsyncBackendRoomEventListener]bool } @@ -107,6 +108,7 @@ func (s *asyncBackendRoomSubscriber) removeListener(listener AsyncBackendRoomEve type asyncRoomSubscriber struct { mu sync.Mutex + // +checklocks:mu listeners map[AsyncRoomEventListener]bool } @@ -142,6 +144,7 @@ func (s *asyncRoomSubscriber) removeListener(listener AsyncRoomEventListener) bo type asyncUserSubscriber struct { mu sync.Mutex + // +checklocks:mu listeners map[AsyncUserEventListener]bool } @@ -177,6 +180,7 @@ func (s *asyncUserSubscriber) removeListener(listener AsyncUserEventListener) bo type asyncSessionSubscriber struct { mu sync.Mutex + // +checklocks:mu listeners map[AsyncSessionEventListener]bool } diff --git a/async_events_nats.go b/async_events_nats.go index 4e52487..b5c3ae5 100644 --- a/async_events_nats.go +++ b/async_events_nats.go @@ -231,10 +231,14 @@ type asyncEventsNats struct { mu sync.Mutex client NatsClient + // +checklocks:mu backendRoomSubscriptions map[string]*asyncBackendRoomSubscriberNats - roomSubscriptions map[string]*asyncRoomSubscriberNats - userSubscriptions map[string]*asyncUserSubscriberNats - sessionSubscriptions map[string]*asyncSessionSubscriberNats + // +checklocks:mu + roomSubscriptions map[string]*asyncRoomSubscriberNats + // +checklocks:mu + userSubscriptions map[string]*asyncUserSubscriberNats + // +checklocks:mu + sessionSubscriptions map[string]*asyncSessionSubscriberNats } func NewAsyncEventsNats(client NatsClient) (AsyncEvents, error) { diff --git a/backend_configuration.go b/backend_configuration.go index 
2da8703..bc604e6 100644 --- a/backend_configuration.go +++ b/backend_configuration.go @@ -56,7 +56,8 @@ type Backend struct { sessionLimit uint64 sessionsLock sync.Mutex - sessions map[PublicSessionId]bool + // +checklocks:sessionsLock + sessions map[PublicSessionId]bool counted bool } @@ -170,7 +171,8 @@ type BackendStorage interface { } type backendStorageCommon struct { - mu sync.RWMutex + mu sync.RWMutex + // +checklocks:mu backends map[string][]*Backend } diff --git a/backend_configuration_stats_prometheus.go b/backend_configuration_stats_prometheus.go index d19d7f9..13664ad 100644 --- a/backend_configuration_stats_prometheus.go +++ b/backend_configuration_stats_prometheus.go @@ -32,7 +32,7 @@ var ( Name: "session_limit", Help: "The session limit of a backend", }, []string{"backend"}) - statsBackendLimitExceededTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + statsBackendLimitExceededTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ // +checklocksignore: Global readonly variable. 
Namespace: "signaling", Subsystem: "backend", Name: "session_limit_exceeded_total", diff --git a/backend_configuration_test.go b/backend_configuration_test.go index f882acd..d5d7bdd 100644 --- a/backend_configuration_test.go +++ b/backend_configuration_test.go @@ -573,7 +573,9 @@ func TestBackendConfiguration_EtcdCompat(t *testing.T) { assert.Equal(secret3, string(backends[0].secret)) } + storage.mu.RLock() _, found := storage.backends["domain1.invalid"] + storage.mu.RUnlock() assert.False(found, "Should have removed host information") } @@ -806,6 +808,8 @@ func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) { <-ch checkStatsValue(t, statsBackendsCurrent, 0) + storage.mu.RLock() _, found := storage.backends["domain1.invalid"] + storage.mu.RUnlock() assert.False(found, "Should have removed host information") } diff --git a/backend_server_test.go b/backend_server_test.go index 7e9f828..5d9eb95 100644 --- a/backend_server_test.go +++ b/backend_server_test.go @@ -628,7 +628,7 @@ func RunTestBackendServer_RoomUpdate(t *testing.T) { emptyProperties := json.RawMessage("{}") backend := hub.backend.GetBackend(u) require.NotNil(backend, "Did not find backend") - room, err := hub.createRoom(roomId, emptyProperties, backend) + room, err := hub.CreateRoom(roomId, emptyProperties, backend) require.NoError(err, "Could not create room") defer room.Close() @@ -696,7 +696,7 @@ func RunTestBackendServer_RoomDelete(t *testing.T) { emptyProperties := json.RawMessage("{}") backend := hub.backend.GetBackend(u) require.NotNil(backend, "Did not find backend") - _, err = hub.createRoom(roomId, emptyProperties, backend) + _, err = hub.CreateRoom(roomId, emptyProperties, backend) require.NoError(err) userid := "test-userid" diff --git a/backend_storage_static.go b/backend_storage_static.go index 1fb77b4..f6e32a4 100644 --- a/backend_storage_static.go +++ b/backend_storage_static.go @@ -151,6 +151,7 @@ func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) func 
(s *backendStorageStatic) Close() { } +// +checklocks:s.mu func (s *backendStorageStatic) RemoveBackendsForHost(host string, seen map[string]seenState) { if oldBackends := s.backends[host]; len(oldBackends) > 0 { deleted := 0 @@ -185,6 +186,7 @@ const ( seenDeleted ) +// +checklocks:s.mu func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend, seen map[string]seenState) { for existingIndex, existingBackend := range s.backends[host] { found := false diff --git a/capabilities.go b/capabilities.go index 7f16344..c5bb1e0 100644 --- a/capabilities.go +++ b/capabilities.go @@ -57,16 +57,22 @@ const ( ) var ( - ErrUnexpectedHttpStatus = errors.New("unexpected_http_status") + ErrUnexpectedHttpStatus = errors.New("unexpected_http_status") // +checklocksignore: Global readonly variable. ) type capabilitiesEntry struct { - c *Capabilities - mu sync.RWMutex - nextUpdate time.Time - etag string + mu sync.RWMutex + // +checklocks:mu + c *Capabilities + + // +checklocks:mu + nextUpdate time.Time + // +checklocks:mu + etag string + // +checklocks:mu mustRevalidate bool - capabilities api.StringMap + // +checklocks:mu + capabilities api.StringMap } func newCapabilitiesEntry(c *Capabilities) *capabilitiesEntry { @@ -82,10 +88,12 @@ func (e *capabilitiesEntry) valid(now time.Time) bool { return e.validLocked(now) } +// +checklocksread:e.mu func (e *capabilitiesEntry) validLocked(now time.Time) bool { return e.nextUpdate.After(now) } +// +checklocks:e.mu func (e *capabilitiesEntry) updateRequest(r *http.Request) { if e.etag != "" { r.Header.Set("If-None-Match", e.etag) @@ -99,6 +107,7 @@ func (e *capabilitiesEntry) invalidate() { e.nextUpdate = time.Now() } +// +checklocks:e.mu func (e *capabilitiesEntry) errorIfMustRevalidate(err error) (bool, error) { if !e.mustRevalidate { return false, nil @@ -238,9 +247,11 @@ type Capabilities struct { // Can be overwritten by tests. 
getNow func() time.Time - version string - pool *HttpClientPool - entries map[string]*capabilitiesEntry + version string + pool *HttpClientPool + // +checklocks:mu + entries map[string]*capabilitiesEntry + // +checklocks:mu nextInvalidate map[string]time.Time buffers BufferPool diff --git a/channel_waiter.go b/channel_waiter.go index 20b0883..dde6cfd 100644 --- a/channel_waiter.go +++ b/channel_waiter.go @@ -26,8 +26,10 @@ import ( ) type ChannelWaiters struct { - mu sync.RWMutex - id uint64 + mu sync.RWMutex + // +checklocks:mu + id uint64 + // +checklocks:mu waiters map[uint64]chan struct{} } diff --git a/client.go b/client.go index 2d06ab8..53f8967 100644 --- a/client.go +++ b/client.go @@ -130,7 +130,8 @@ type Client struct { logRTT bool handlerMu sync.RWMutex - handler ClientHandler + // +checklocks:handlerMu + handler ClientHandler session atomic.Pointer[Session] sessionId atomic.Pointer[PublicSessionId] diff --git a/client/main.go b/client/main.go index 7145fee..3102cff 100644 --- a/client/main.go +++ b/client/main.go @@ -113,7 +113,7 @@ type MessagePayload struct { } type SignalingClient struct { - readyWg *sync.WaitGroup + readyWg *sync.WaitGroup // +checklocksignore: Only written to from constructor. 
cookie *signaling.SessionIdCodec conn *websocket.Conn @@ -123,10 +123,13 @@ type SignalingClient struct { stopChan chan struct{} - lock sync.Mutex + lock sync.Mutex + // +checklocks:lock privateSessionId signaling.PrivateSessionId - publicSessionId signaling.PublicSessionId - userId string + // +checklocks:lock + publicSessionId signaling.PublicSessionId + // +checklocks:lock + userId string } func NewSignalingClient(cookie *signaling.SessionIdCodec, url string, stats *Stats, readyWg *sync.WaitGroup, doneWg *sync.WaitGroup) (*SignalingClient, error) { diff --git a/clientsession.go b/clientsession.go index 0f5f420..dd52e88 100644 --- a/clientsession.go +++ b/clientsession.go @@ -70,9 +70,11 @@ type ClientSession struct { parseUserData func() (api.StringMap, error) - inCall Flags + inCall Flags + // +checklocks:mu supportsPermissions bool - permissions map[Permission]bool + // +checklocks:mu + permissions map[Permission]bool backend *Backend backendUrl string @@ -80,30 +82,40 @@ type ClientSession struct { mu sync.Mutex + // +checklocks:mu client HandlerClient room atomic.Pointer[Room] roomJoinTime atomic.Int64 federation atomic.Pointer[FederationClient] roomSessionIdLock sync.RWMutex - roomSessionId RoomSessionId + // +checklocks:roomSessionIdLock + roomSessionId RoomSessionId - publisherWaiters ChannelWaiters + publisherWaiters ChannelWaiters // +checklocksignore - publishers map[StreamType]McuPublisher + // +checklocks:mu + publishers map[StreamType]McuPublisher + // +checklocks:mu subscribers map[StreamId]McuSubscriber - pendingClientMessages []*ServerMessage - hasPendingChat bool + // +checklocks:mu + pendingClientMessages []*ServerMessage + // +checklocks:mu + hasPendingChat bool + // +checklocks:mu hasPendingParticipantsUpdate bool + // +checklocks:mu virtualSessions map[*VirtualSession]bool - seenJoinedLock sync.Mutex + seenJoinedLock sync.Mutex + // +checklocks:seenJoinedLock seenJoinedEvents map[PublicSessionId]bool responseHandlersLock sync.Mutex - 
responseHandlers map[string]ResponseHandlerFunc + // +checklocks:responseHandlersLock + responseHandlers map[string]ResponseHandlerFunc } func NewClientSession(hub *Hub, privateId PrivateSessionId, publicId PublicSessionId, data *SessionIdData, backend *Backend, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) { @@ -197,6 +209,19 @@ func (s *ClientSession) HasPermission(permission Permission) bool { return s.hasPermissionLocked(permission) } +func (s *ClientSession) GetPermissions() []Permission { + s.mu.Lock() + defer s.mu.Unlock() + + result := make([]Permission, 0, len(s.permissions)) + for p, ok := range s.permissions { + if ok { + result = append(result, p) + } + } + return result +} + // HasAnyPermission checks if the session has one of the passed permissions. func (s *ClientSession) HasAnyPermission(permission ...Permission) bool { if len(permission) == 0 { @@ -209,16 +234,16 @@ func (s *ClientSession) HasAnyPermission(permission ...Permission) bool { return s.hasAnyPermissionLocked(permission...) } +// +checklocks:s.mu func (s *ClientSession) hasAnyPermissionLocked(permission ...Permission) bool { if len(permission) == 0 { return false } - return slices.ContainsFunc(permission, func(p Permission) bool { - return s.hasPermissionLocked(p) - }) + return slices.ContainsFunc(permission, s.hasPermissionLocked) } +// +checklocks:s.mu func (s *ClientSession) hasPermissionLocked(permission Permission) bool { if !s.supportsPermissions { // Old-style session that doesn't receive permissions from Nextcloud. 
@@ -340,6 +365,7 @@ func (s *ClientSession) getRoomJoinTime() time.Time { return time.Unix(0, t) } +// +checklocks:s.mu func (s *ClientSession) releaseMcuObjects() { if len(s.publishers) > 0 { go func(publishers map[StreamType]McuPublisher) { @@ -489,6 +515,7 @@ func (s *ClientSession) LeaveRoomWithMessage(notify bool, message *ClientMessage return s.doLeaveRoom(notify) } +// +checklocks:s.mu func (s *ClientSession) doLeaveRoom(notify bool) *Room { room := s.GetRoom() if room == nil { @@ -543,6 +570,7 @@ func (s *ClientSession) ClearClient(client HandlerClient) { s.clearClientLocked(client) } +// +checklocks:s.mu func (s *ClientSession) clearClientLocked(client HandlerClient) { if s.client == nil { return @@ -563,6 +591,7 @@ func (s *ClientSession) GetClient() HandlerClient { return s.getClientUnlocked() } +// +checklocks:s.mu func (s *ClientSession) getClientUnlocked() HandlerClient { return s.client } @@ -589,6 +618,7 @@ func (s *ClientSession) SetClient(client HandlerClient) HandlerClient { return prev } +// +checklocks:s.mu func (s *ClientSession) sendOffer(client McuClient, sender PublicSessionId, streamType StreamType, offer api.StringMap) { offer_message := &AnswerOfferMessage{ To: s.PublicId(), @@ -617,6 +647,7 @@ func (s *ClientSession) sendOffer(client McuClient, sender PublicSessionId, stre s.sendMessageUnlocked(response_message) } +// +checklocks:s.mu func (s *ClientSession) sendCandidate(client McuClient, sender PublicSessionId, streamType StreamType, candidate any) { candidate_message := &AnswerOfferMessage{ To: s.PublicId(), @@ -647,6 +678,7 @@ func (s *ClientSession) sendCandidate(client McuClient, sender PublicSessionId, s.sendMessageUnlocked(response_message) } +// +checklocks:s.mu func (s *ClientSession) sendMessageUnlocked(message *ServerMessage) bool { if c := s.getClientUnlocked(); c != nil { if c.SendMessage(message) { @@ -768,6 +800,7 @@ func (e *PermissionError) Error() string { return fmt.Sprintf("permission \"%s\" not found", 
e.permission) } +// +checklocks:s.mu func (s *ClientSession) isSdpAllowedToSendLocked(sdp *sdp.SessionDescription) (MediaType, error) { if sdp == nil { // Should have already been checked when data was validated. @@ -832,6 +865,7 @@ func (s *ClientSession) CheckOfferType(streamType StreamType, data *MessageClien return s.checkOfferTypeLocked(streamType, data) } +// +checklocks:s.mu func (s *ClientSession) checkOfferTypeLocked(streamType StreamType, data *MessageClientMessageData) (MediaType, error) { if streamType == StreamTypeScreen { if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) { @@ -893,6 +927,8 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea if err != nil { return nil, err } + s.mu.Lock() + defer s.mu.Unlock() if s.publishers == nil { s.publishers = make(map[StreamType]McuPublisher) } @@ -915,6 +951,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea return publisher, nil } +// +checklocks:s.mu func (s *ClientSession) getPublisherLocked(streamType StreamType) McuPublisher { return s.publishers[streamType] } @@ -1120,6 +1157,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) { s.SendMessage(serverMessage) } +// +checklocks:s.mu func (s *ClientSession) storePendingMessage(message *ServerMessage) { if message.IsChatRefresh() { if s.hasPendingChat { diff --git a/concurrentmap.go b/concurrentmap.go index 8bc83a4..255d2a1 100644 --- a/concurrentmap.go +++ b/concurrentmap.go @@ -27,7 +27,8 @@ import ( type ConcurrentMap[K comparable, V any] struct { mu sync.RWMutex - d map[K]V + // +checklocks:mu + d map[K]V // +checklocksignore: Not supported yet, see https://github.com/google/gvisor/issues/11671 } func (m *ConcurrentMap[K, V]) Set(key K, value V) { diff --git a/dns_monitor.go b/dns_monitor.go index d324f04..dcfba64 100644 --- a/dns_monitor.go +++ b/dns_monitor.go @@ -57,11 +57,28 @@ type dnsMonitorEntry struct { hostname string hostIP net.IP - mu sync.Mutex - ips 
[]net.IP + mu sync.Mutex + // +checklocks:mu + ips []net.IP + // +checklocks:mu entries map[*DnsMonitorEntry]bool } +func (e *dnsMonitorEntry) clearRemoved() bool { + e.mu.Lock() + defer e.mu.Unlock() + + deleted := false + for entry := range e.entries { + if entry.entry.Load() == nil { + delete(e.entries, entry) + deleted = true + } + } + + return deleted && len(e.entries) == 0 +} + func (e *dnsMonitorEntry) setIPs(ips []net.IP, fromIP bool) { e.mu.Lock() defer e.mu.Unlock() @@ -134,6 +151,7 @@ func (e *dnsMonitorEntry) removeEntry(entry *DnsMonitorEntry) bool { return len(e.entries) == 0 } +// +checklocks:e.mu func (e *dnsMonitorEntry) runCallbacks(all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) { for entry := range e.entries { entry.callback(entry, all, add, keep, remove) @@ -247,7 +265,7 @@ func (m *DnsMonitor) Remove(entry *DnsMonitorEntry) { m.hasRemoved.Store(true) return } - defer m.mu.Unlock() + defer m.mu.Unlock() // +checklocksforce: only executed if the TryLock above succeeded. e, found := m.hostnames[oldEntry.hostname] if !found { @@ -268,15 +286,7 @@ func (m *DnsMonitor) clearRemoved() { defer m.mu.Unlock() for hostname, entry := range m.hostnames { - deleted := false - for e := range entry.entries { - if e.entry.Load() == nil { - delete(entry.entries, e) - deleted = true - } - } - - if deleted && len(entry.entries) == 0 { + if entry.clearRemoved() { delete(m.hostnames, hostname) } } diff --git a/dns_monitor_test.go b/dns_monitor_test.go index fbb7762..daa92ef 100644 --- a/dns_monitor_test.go +++ b/dns_monitor_test.go @@ -38,6 +38,7 @@ import ( type mockDnsLookup struct { sync.RWMutex + // +checklocks:RWMutex ips map[string][]net.IP } @@ -118,14 +119,16 @@ func (r *dnsMonitorReceiverRecord) String() string { } var ( - expectNone = &dnsMonitorReceiverRecord{} + expectNone = &dnsMonitorReceiverRecord{} // +checklocksignore: Global readonly variable. 
) type dnsMonitorReceiver struct { sync.Mutex - t *testing.T + t *testing.T + // +checklocks:Mutex expected *dnsMonitorReceiverRecord + // +checklocks:Mutex received *dnsMonitorReceiverRecord } @@ -328,13 +331,15 @@ func TestDnsMonitorNoLookupIfEmpty(t *testing.T) { type deadlockMonitorReceiver struct { t *testing.T - monitor *DnsMonitor + monitor *DnsMonitor // +checklocksignore: Only written to from constructor. mu sync.RWMutex wg sync.WaitGroup - entry *DnsMonitorEntry - started chan struct{} + // +checklocks:mu + entry *DnsMonitorEntry + started chan struct{} + // +checklocks:mu triggered bool closed atomic.Bool } diff --git a/etcd_client.go b/etcd_client.go index 459f14b..1780679 100644 --- a/etcd_client.go +++ b/etcd_client.go @@ -55,8 +55,9 @@ type EtcdClientWatcher interface { type EtcdClient struct { compatSection string - mu sync.Mutex - client atomic.Value + mu sync.Mutex + client atomic.Value + // +checklocks:mu listeners map[EtcdClientListener]bool } diff --git a/federation.go b/federation.go index 458b265..f491616 100644 --- a/federation.go +++ b/federation.go @@ -98,6 +98,7 @@ type FederationClient struct { resumeId PrivateSessionId hello atomic.Pointer[HelloServerMessage] + // +checklocks:helloMu pendingMessages []*ClientMessage closeOnLeave atomic.Bool diff --git a/geoip.go b/geoip.go index ec461e5..1051324 100644 --- a/geoip.go +++ b/geoip.go @@ -32,7 +32,7 @@ import ( "net/url" "os" "strings" - "sync" + "sync/atomic" "time" "github.com/dlintw/goconf" @@ -59,12 +59,11 @@ type GeoLookup struct { url string isFile bool client http.Client - mu sync.Mutex - lastModifiedHeader string - lastModifiedTime time.Time + lastModifiedHeader atomic.Value + lastModifiedTime atomic.Int64 - reader *maxminddb.Reader + reader atomic.Pointer[maxminddb.Reader] } func NewGeoLookupFromUrl(url string) (*GeoLookup, error) { @@ -87,12 +86,9 @@ func NewGeoLookupFromFile(filename string) (*GeoLookup, error) { } func (g *GeoLookup) Close() { - g.mu.Lock() - if g.reader != nil 
{ - g.reader.Close() - g.reader = nil + if reader := g.reader.Swap(nil); reader != nil { + reader.Close() } - g.mu.Unlock() } func (g *GeoLookup) Update() error { @@ -109,7 +105,7 @@ func (g *GeoLookup) updateFile() error { return err } - if info.ModTime().Equal(g.lastModifiedTime) { + if info.ModTime().UnixNano() == g.lastModifiedTime.Load() { return nil } @@ -125,13 +121,11 @@ func (g *GeoLookup) updateFile() error { metadata := reader.Metadata log.Printf("Using %s GeoIP database from %s (built on %s)", metadata.DatabaseType, g.url, time.Unix(int64(metadata.BuildEpoch), 0).UTC()) - g.mu.Lock() - if g.reader != nil { - g.reader.Close() + if old := g.reader.Swap(reader); old != nil { + old.Close() } - g.reader = reader - g.lastModifiedTime = info.ModTime() - g.mu.Unlock() + + g.lastModifiedTime.Store(info.ModTime().UnixNano()) return nil } @@ -140,8 +134,8 @@ func (g *GeoLookup) updateUrl() error { if err != nil { return err } - if g.lastModifiedHeader != "" { - request.Header.Add("If-Modified-Since", g.lastModifiedHeader) + if header := g.lastModifiedHeader.Load(); header != nil { + request.Header.Add("If-Modified-Since", header.(string)) } response, err := g.client.Do(request) if err != nil { @@ -210,13 +204,11 @@ func (g *GeoLookup) updateUrl() error { metadata := reader.Metadata log.Printf("Using %s GeoIP database from %s (built on %s)", metadata.DatabaseType, g.url, time.Unix(int64(metadata.BuildEpoch), 0).UTC()) - g.mu.Lock() - if g.reader != nil { - g.reader.Close() + if old := g.reader.Swap(reader); old != nil { + old.Close() } - g.reader = reader - g.lastModifiedHeader = response.Header.Get("Last-Modified") - g.mu.Unlock() + + g.lastModifiedHeader.Store(response.Header.Get("Last-Modified")) return nil } @@ -227,14 +219,12 @@ func (g *GeoLookup) LookupCountry(ip net.IP) (string, error) { } `maxminddb:"country"` } - g.mu.Lock() - if g.reader == nil { - g.mu.Unlock() + reader := g.reader.Load() + if reader == nil { return "", ErrDatabaseNotInitialized } - err 
:= g.reader.Lookup(ip, &record) - g.mu.Unlock() - if err != nil { + + if err := reader.Lookup(ip, &record); err != nil { return "", err } diff --git a/grpc_client.go b/grpc_client.go index e92076d..38a4270 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -414,14 +414,18 @@ type GrpcClients struct { mu sync.RWMutex version string + // +checklocks:mu clientsMap map[string]*grpcClientsList - clients []*GrpcClient + // +checklocks:mu + clients []*GrpcClient - dnsMonitor *DnsMonitor + dnsMonitor *DnsMonitor + // +checklocks:mu dnsDiscovery bool - etcdClient *EtcdClient - targetPrefix string + etcdClient *EtcdClient // +checklocksignore: Only written to from constructor. + targetPrefix string + // +checklocks:mu targetInformation map[string]*GrpcTargetInformationEtcd dialOptions atomic.Value // []grpc.DialOption creds credentials.TransportCredentials @@ -432,7 +436,7 @@ type GrpcClients struct { selfCheckWaitGroup sync.WaitGroup closeCtx context.Context - closeFunc context.CancelFunc + closeFunc context.CancelFunc // +checklocksignore: No locking necessary. 
} func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient, dnsMonitor *DnsMonitor, version string) (*GrpcClients, error) { @@ -755,6 +759,9 @@ func (c *GrpcClients) onLookup(entry *DnsMonitorEntry, all []net.IP, added []net } func (c *GrpcClients) loadTargetsEtcd(config *goconf.ConfigFile, fromReload bool, opts ...grpc.DialOption) error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.etcdClient.IsConfigured() { return fmt.Errorf("no etcd endpoints configured") } @@ -894,6 +901,7 @@ func (c *GrpcClients) EtcdKeyDeleted(client *EtcdClient, key string, prevValue [ c.removeEtcdClientLocked(key) } +// +checklocks:c.mu func (c *GrpcClients) removeEtcdClientLocked(key string) { info, found := c.targetInformation[key] if !found { diff --git a/grpc_stats_prometheus.go b/grpc_stats_prometheus.go index fa37bdc..4697fc3 100644 --- a/grpc_stats_prometheus.go +++ b/grpc_stats_prometheus.go @@ -26,7 +26,7 @@ import ( ) var ( - statsGrpcClients = prometheus.NewGauge(prometheus.GaugeOpts{ + statsGrpcClients = prometheus.NewGauge(prometheus.GaugeOpts{ // +checklocksignore: Global readonly variable. Namespace: "signaling", Subsystem: "grpc", Name: "clients", diff --git a/http_client_pool.go b/http_client_pool.go index 65c19dc..b257962 100644 --- a/http_client_pool.go +++ b/http_client_pool.go @@ -79,9 +79,10 @@ type HttpClientPool struct { mu sync.Mutex transport *http.Transport - clients map[string]*Pool + // +checklocks:mu + clients map[string]*Pool - maxConcurrentRequestsPerHost int + maxConcurrentRequestsPerHost int // +checklocksignore: Only written to from constructor. 
} func NewHttpClientPool(maxConcurrentRequestsPerHost int, skipVerify bool) (*HttpClientPool, error) { diff --git a/hub.go b/hub.go index 29fb003..90abc74 100644 --- a/hub.go +++ b/hub.go @@ -162,13 +162,17 @@ type Hub struct { mu sync.RWMutex ru sync.RWMutex - sid atomic.Uint64 - clients map[uint64]HandlerClient + sid atomic.Uint64 + // +checklocks:mu + clients map[uint64]HandlerClient + // +checklocks:mu sessions map[uint64]Session - rooms map[string]*Room + // +checklocks:ru + rooms map[string]*Room - roomSessions RoomSessions - roomPing *RoomPing + roomSessions RoomSessions + roomPing *RoomPing + // +checklocks:mu virtualSessions map[PublicSessionId]uint64 decodeCaches []*LruCache[*SessionIdData] @@ -179,12 +183,18 @@ type Hub struct { allowSubscribeAnyStream bool - expiredSessions map[Session]time.Time - anonymousSessions map[*ClientSession]time.Time + // +checklocks:mu + expiredSessions map[Session]time.Time + // +checklocks:mu + anonymousSessions map[*ClientSession]time.Time + // +checklocks:mu expectHelloClients map[HandlerClient]time.Time - dialoutSessions map[*ClientSession]bool - remoteSessions map[*RemoteSession]bool - federatedSessions map[*ClientSession]bool + // +checklocks:mu + dialoutSessions map[*ClientSession]bool + // +checklocks:mu + remoteSessions map[*RemoteSession]bool + // +checklocks:mu + federatedSessions map[*ClientSession]bool backendTimeout time.Duration backend *BackendClient @@ -748,6 +758,7 @@ func (h *Hub) CreateProxyToken(publisherId string) (string, error) { return proxy.createToken(publisherId) } +// +checklocks:h.mu func (h *Hub) checkExpiredSessions(now time.Time) { for session, expires := range h.expiredSessions { if now.After(expires) { @@ -761,6 +772,7 @@ func (h *Hub) checkExpiredSessions(now time.Time) { } } +// +checklocks:h.mu func (h *Hub) checkAnonymousSessions(now time.Time) { for session, timeout := range h.anonymousSessions { if now.After(timeout) { @@ -775,6 +787,7 @@ func (h *Hub) checkAnonymousSessions(now 
time.Time) { } } +// +checklocks:h.mu func (h *Hub) checkInitialHello(now time.Time) { for client, timeout := range h.expectHelloClients { if now.After(timeout) { @@ -788,10 +801,11 @@ func (h *Hub) checkInitialHello(now time.Time) { func (h *Hub) performHousekeeping(now time.Time) { h.mu.Lock() + defer h.mu.Unlock() + h.checkExpiredSessions(now) h.checkAnonymousSessions(now) h.checkInitialHello(now) - h.mu.Unlock() } func (h *Hub) removeSession(session Session) (removed bool) { @@ -820,6 +834,7 @@ func (h *Hub) removeSession(session Session) (removed bool) { return } +// +checklocksread:h.mu func (h *Hub) hasSessionsLocked(withInternal bool) bool { if withInternal { return len(h.sessions) > 0 @@ -841,6 +856,7 @@ func (h *Hub) startWaitAnonymousSessionRoom(session *ClientSession) { h.startWaitAnonymousSessionRoomLocked(session) } +// +checklocks:h.mu func (h *Hub) startWaitAnonymousSessionRoomLocked(session *ClientSession) { if session.ClientType() == HelloClientTypeInternal { // Internal clients don't need to join a room. 
@@ -1629,7 +1645,7 @@ func (h *Hub) sendRoom(session *ClientSession, message *ClientMessage, room *Roo } else { response.Room = &RoomServerMessage{ RoomId: room.id, - Properties: room.properties, + Properties: room.Properties(), } } return session.SendMessage(response) @@ -1764,7 +1780,7 @@ func (h *Hub) processRoom(sess Session, message *ClientMessage) { NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{ Room: &RoomServerMessage{ RoomId: room.id, - Properties: room.properties, + Properties: room.Properties(), }, }), )) @@ -1907,7 +1923,15 @@ func (h *Hub) removeRoom(room *Room) { h.roomPing.DeleteRoom(room.Id()) } -func (h *Hub) createRoom(id string, properties json.RawMessage, backend *Backend) (*Room, error) { +func (h *Hub) CreateRoom(id string, properties json.RawMessage, backend *Backend) (*Room, error) { + h.ru.Lock() + defer h.ru.Unlock() + + return h.createRoomLocked(id, properties, backend) +} + +// +checklocks:h.ru +func (h *Hub) createRoomLocked(id string, properties json.RawMessage, backend *Backend) (*Room, error) { // Note the write lock must be held. room, err := NewRoom(id, properties, h, h.events, backend) if err != nil { @@ -1947,7 +1971,7 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro r, found := h.rooms[internalRoomId] if !found { var err error - if r, err = h.createRoom(roomId, room.Room.Properties, session.Backend()); err != nil { + if r, err = h.createRoomLocked(roomId, room.Room.Properties, session.Backend()); err != nil { h.ru.Unlock() session.SendMessage(message.NewWrappedErrorServerMessage(err)) // The session (implicitly) left the room due to an error. 
diff --git a/hub_stats_prometheus.go b/hub_stats_prometheus.go index 0c793d7..2e847c9 100644 --- a/hub_stats_prometheus.go +++ b/hub_stats_prometheus.go @@ -26,7 +26,7 @@ import ( ) var ( - statsHubRoomsCurrent = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + statsHubRoomsCurrent = prometheus.NewGaugeVec(prometheus.GaugeOpts{ // +checklocksignore: Global readonly variable. Namespace: "signaling", Subsystem: "hub", Name: "rooms", diff --git a/hub_test.go b/hub_test.go index c6f7425..3b9d28e 100644 --- a/hub_test.go +++ b/hub_test.go @@ -466,6 +466,7 @@ func processRoomRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re var ( sessionRequestHander struct { sync.Mutex + // +checklocks:Mutex handlers map[*testing.T]func(*BackendClientSessionRequest) } ) @@ -2843,8 +2844,8 @@ func TestInitialRoomPermissions(t *testing.T) { session := hub.GetSessionByPublicId(hello.Hello.SessionId).(*ClientSession) require.NotNil(session, "Session %s does not exist", hello.Hello.SessionId) - assert.True(session.HasPermission(PERMISSION_MAY_PUBLISH_AUDIO), "Session %s should have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_AUDIO, session.permissions) - assert.False(session.HasPermission(PERMISSION_MAY_PUBLISH_VIDEO), "Session %s should not have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_VIDEO, session.permissions) + assert.True(session.HasPermission(PERMISSION_MAY_PUBLISH_AUDIO), "Session %s should have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_AUDIO, session.GetPermissions()) + assert.False(session.HasPermission(PERMISSION_MAY_PUBLISH_VIDEO), "Session %s should not have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_VIDEO, session.GetPermissions()) } func TestJoinRoomSwitchClient(t *testing.T) { diff --git a/janus_client.go b/janus_client.go index 3aa8f6f..8686dea 100644 --- a/janus_client.go +++ b/janus_client.go @@ -238,15 +238,18 @@ type JanusGateway struct { listener GatewayListener // Sessions is a map of the currently 
active sessions to the gateway. + // +checklocks:Mutex Sessions map[uint64]*JanusSession // Access to the Sessions map should be synchronized with the Gateway.Lock() // and Gateway.Unlock() methods provided by the embedded sync.Mutex. sync.Mutex + // +checklocks:writeMu conn *websocket.Conn nextTransaction atomic.Uint64 - transactions map[uint64]*transaction + // +checklocks:Mutex + transactions map[uint64]*transaction closer *Closer @@ -592,6 +595,7 @@ type JanusSession struct { Id uint64 // Handles is a map of plugin handles within this session + // +checklocks:Mutex Handles map[uint64]*JanusHandle // Access to the Handles map should be synchronized with the Session.Lock() diff --git a/lru.go b/lru.go index 6bf8bbf..781e8be 100644 --- a/lru.go +++ b/lru.go @@ -32,10 +32,12 @@ type cacheEntry[T any] struct { } type LruCache[T any] struct { - size int - mu sync.Mutex + size int // +checklocksignore: Only written to from constructor. + mu sync.Mutex + // +checklocks:mu entries *list.List - data map[string]*list.Element + // +checklocks:mu + data map[string]*list.Element } func NewLruCache[T any](size int) *LruCache[T] { @@ -88,6 +90,7 @@ func (c *LruCache[T]) Remove(key string) { c.mu.Unlock() } +// +checklocks:c.mu func (c *LruCache[T]) removeOldestLocked() { v := c.entries.Back() if v != nil { @@ -101,6 +104,7 @@ func (c *LruCache[T]) RemoveOldest() { c.mu.Unlock() } +// +checklocks:c.mu func (c *LruCache[T]) removeElement(e *list.Element) { c.entries.Remove(e) entry := e.Value.(*cacheEntry[T]) diff --git a/mcu_janus.go b/mcu_janus.go index df68df0..72c8472 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -221,13 +221,16 @@ type mcuJanus struct { closeChan chan struct{} muClients sync.Mutex - clients map[clientInterface]bool - clientId atomic.Uint64 + // +checklocks:muClients + clients map[clientInterface]bool + clientId atomic.Uint64 + // +checklocks:mu publishers map[StreamId]*mcuJanusPublisher publisherCreated Notifier publisherConnected Notifier - 
remotePublishers map[StreamId]*mcuJanusRemotePublisher + // +checklocks:mu + remotePublishers map[StreamId]*mcuJanusRemotePublisher reconnectTimer *time.Timer reconnectInterval time.Duration @@ -684,8 +687,6 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id Pu streamType: streamType, maxBitrate: maxBitrate, - handle: handle, - handleId: handle.Id, closeChan: make(chan struct{}, 1), deferred: make(chan func(), 64), }, @@ -693,6 +694,8 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id Pu id: id, settings: settings, } + client.handle.Store(handle) + client.handleId.Store(handle.Id) client.mcuJanusClient.handleEvent = client.handleEvent client.mcuJanusClient.handleHangup = client.handleHangup client.mcuJanusClient.handleDetached = client.handleDetached @@ -701,7 +704,7 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id Pu client.mcuJanusClient.handleMedia = client.handleMedia m.registerClient(client) - log.Printf("Publisher %s is using handle %d", client.id, client.handleId) + log.Printf("Publisher %s is using handle %d", client.id, handle.Id) go client.run(handle, client.closeChan) m.mu.Lock() m.publishers[getStreamId(id, streamType)] = client @@ -781,13 +784,13 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ streamType: streamType, maxBitrate: pub.MaxBitrate(), - handle: handle, - handleId: handle.Id, closeChan: make(chan struct{}, 1), deferred: make(chan func(), 64), }, publisher: publisher, } + client.handle.Store(handle) + client.handleId.Store(handle.Id) client.mcuJanusClient.handleEvent = client.handleEvent client.mcuJanusClient.handleHangup = client.handleHangup client.mcuJanusClient.handleDetached = client.handleDetached @@ -865,8 +868,6 @@ func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller Re streamType: streamType, maxBitrate: maxBitrate, - handle: handle, - handleId: handle.Id, closeChan: make(chan 
struct{}, 1), deferred: make(chan func(), 64), }, @@ -881,6 +882,8 @@ func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller Re port: int(port), rtcpPort: int(rtcp_port), } + pub.handle.Store(handle) + pub.handleId.Store(handle.Id) pub.mcuJanusClient.handleEvent = pub.handleEvent pub.mcuJanusClient.handleHangup = pub.handleHangup pub.mcuJanusClient.handleDetached = pub.handleDetached @@ -946,8 +949,6 @@ func (m *mcuJanus) NewRemoteSubscriber(ctx context.Context, listener McuListener streamType: publisher.StreamType(), maxBitrate: pub.MaxBitrate(), - handle: handle, - handleId: handle.Id, closeChan: make(chan struct{}, 1), deferred: make(chan func(), 64), }, @@ -956,6 +957,8 @@ func (m *mcuJanus) NewRemoteSubscriber(ctx context.Context, listener McuListener } client.remote.Store(pub) pub.addRef() + client.handle.Store(handle) + client.handleId.Store(handle.Id) client.mcuJanusClient.handleEvent = client.handleEvent client.mcuJanusClient.handleHangup = client.handleHangup client.mcuJanusClient.handleDetached = client.handleDetached diff --git a/mcu_janus_client.go b/mcu_janus_client.go index 2165509..88d4688 100644 --- a/mcu_janus_client.go +++ b/mcu_janus_client.go @@ -27,6 +27,7 @@ import ( "reflect" "strconv" "sync" + "sync/atomic" "github.com/notedit/janus-go" @@ -45,8 +46,8 @@ type mcuJanusClient struct { streamType StreamType maxBitrate int - handle *JanusHandle - handleId uint64 + handle atomic.Pointer[JanusHandle] + handleId atomic.Uint64 closeChan chan struct{} deferred chan func() @@ -81,8 +82,7 @@ func (c *mcuJanusClient) SendMessage(ctx context.Context, message *MessageClient } func (c *mcuJanusClient) closeClient(ctx context.Context) bool { - if handle := c.handle; handle != nil { - c.handle = nil + if handle := c.handle.Swap(nil); handle != nil { close(c.closeChan) if _, err := handle.Detach(ctx); err != nil { if e, ok := err.(*janus.ErrorMsg); !ok || e.Err.Code != JANUS_ERROR_HANDLE_NOT_FOUND { @@ -127,7 +127,7 @@ loop: } func (c 
*mcuJanusClient) sendOffer(ctx context.Context, offer api.StringMap, callback func(error, api.StringMap)) { - handle := c.handle + handle := c.handle.Load() if handle == nil { callback(ErrNotConnected, nil) return @@ -149,7 +149,7 @@ func (c *mcuJanusClient) sendOffer(ctx context.Context, offer api.StringMap, cal } func (c *mcuJanusClient) sendAnswer(ctx context.Context, answer api.StringMap, callback func(error, api.StringMap)) { - handle := c.handle + handle := c.handle.Load() if handle == nil { callback(ErrNotConnected, nil) return @@ -169,7 +169,7 @@ func (c *mcuJanusClient) sendAnswer(ctx context.Context, answer api.StringMap, c } func (c *mcuJanusClient) sendCandidate(ctx context.Context, candidate any, callback func(error, api.StringMap)) { - handle := c.handle + handle := c.handle.Load() if handle == nil { callback(ErrNotConnected, nil) return @@ -191,7 +191,7 @@ func (c *mcuJanusClient) handleTrickle(event *TrickleMsg) { } func (c *mcuJanusClient) selectStream(ctx context.Context, stream *streamSelection, callback func(error, api.StringMap)) { - handle := c.handle + handle := c.handle.Load() if handle == nil { callback(ErrNotConnected, nil) return diff --git a/mcu_janus_publisher.go b/mcu_janus_publisher.go index c4b7e7f..c0135f2 100644 --- a/mcu_janus_publisher.go +++ b/mcu_janus_publisher.go @@ -67,38 +67,38 @@ func (p *mcuJanusPublisher) handleEvent(event *janus.EventMsg) { ctx := context.TODO() switch videoroom { case "destroyed": - log.Printf("Publisher %d: associated room has been destroyed, closing", p.handleId) + log.Printf("Publisher %d: associated room has been destroyed, closing", p.handleId.Load()) go p.Close(ctx) case "slow_link": // Ignore, processed through "handleSlowLink" in the general events. 
default: - log.Printf("Unsupported videoroom publisher event in %d: %+v", p.handleId, event) + log.Printf("Unsupported videoroom publisher event in %d: %+v", p.handleId.Load(), event) } } else { - log.Printf("Unsupported publisher event in %d: %+v", p.handleId, event) + log.Printf("Unsupported publisher event in %d: %+v", p.handleId.Load(), event) } } func (p *mcuJanusPublisher) handleHangup(event *janus.HangupMsg) { - log.Printf("Publisher %d received hangup (%s), closing", p.handleId, event.Reason) + log.Printf("Publisher %d received hangup (%s), closing", p.handleId.Load(), event.Reason) go p.Close(context.Background()) } func (p *mcuJanusPublisher) handleDetached(event *janus.DetachedMsg) { - log.Printf("Publisher %d received detached, closing", p.handleId) + log.Printf("Publisher %d received detached, closing", p.handleId.Load()) go p.Close(context.Background()) } func (p *mcuJanusPublisher) handleConnected(event *janus.WebRTCUpMsg) { - log.Printf("Publisher %d received connected", p.handleId) + log.Printf("Publisher %d received connected", p.handleId.Load()) p.mcu.publisherConnected.Notify(string(getStreamId(p.id, p.streamType))) } func (p *mcuJanusPublisher) handleSlowLink(event *janus.SlowLinkMsg) { if event.Uplink { - log.Printf("Publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost) + log.Printf("Publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } else { - log.Printf("Publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost) + log.Printf("Publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } } @@ -129,18 +129,22 @@ func (p *mcuJanusPublisher) NotifyReconnected() { return } - p.handle = handle - p.handleId = handle.Id + if prev := 
p.handle.Swap(handle); prev != nil { + if _, err := prev.Detach(context.Background()); err != nil { + log.Printf("Error detaching old publisher handle %d: %s", prev.Id, err) + } + } + p.handleId.Store(handle.Id) p.session = session p.roomId = roomId - log.Printf("Publisher %s reconnected on handle %d", p.id, p.handleId) + log.Printf("Publisher %s reconnected on handle %d", p.id, p.handleId.Load()) } func (p *mcuJanusPublisher) Close(ctx context.Context) { notify := false p.mu.Lock() - if handle := p.handle; handle != nil && p.roomId != 0 { + if handle := p.handle.Load(); handle != nil && p.roomId != 0 { destroy_msg := api.StringMap{ "request": "destroy", "room": p.roomId, @@ -399,6 +403,11 @@ func getPublisherRemoteId(id PublicSessionId, remoteId PublicSessionId, hostname } func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId PublicSessionId, hostname string, port int, rtcpPort int) error { + handle := p.handle.Load() + if handle == nil { + return ErrNotConnected + } + msg := api.StringMap{ "request": "publish_remotely", "room": p.roomId, @@ -408,7 +417,7 @@ func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId PublicSe "port": port, "rtcp_port": rtcpPort, } - response, err := p.handle.Request(ctx, msg) + response, err := handle.Request(ctx, msg) if err != nil { return err } @@ -436,13 +445,18 @@ func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId PublicSe } func (p *mcuJanusPublisher) UnpublishRemote(ctx context.Context, remoteId PublicSessionId, hostname string, port int, rtcpPort int) error { + handle := p.handle.Load() + if handle == nil { + return ErrNotConnected + } + msg := api.StringMap{ "request": "unpublish_remotely", "room": p.roomId, "publisher_id": streamTypeUserIds[p.streamType], "remote_id": getPublisherRemoteId(p.id, remoteId, hostname, port, rtcpPort), } - response, err := p.handle.Request(ctx, msg) + response, err := handle.Request(ctx, msg) if err != nil { return err } diff --git 
a/mcu_janus_publisher_test.go b/mcu_janus_publisher_test.go index 37aa890..98ae23b 100644 --- a/mcu_janus_publisher_test.go +++ b/mcu_janus_publisher_test.go @@ -22,9 +22,15 @@ package signaling import ( + "context" + "sync/atomic" "testing" + "github.com/notedit/janus-go" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/strukturag/nextcloud-spreed-signaling/api" ) func TestGetFmtpValueH264(t *testing.T) { @@ -94,3 +100,78 @@ func TestGetFmtpValueVP9(t *testing.T) { } } } + +func TestJanusPublisherRemote(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + require := require.New(t) + assert := assert.New(t) + + var remotePublishId atomic.Value + + remoteId := PublicSessionId("the-remote-id") + hostname := "remote.server" + port := 12345 + rtcpPort := 23456 + + mcu, gateway := newMcuJanusForTesting(t) + gateway.registerHandlers(map[string]TestJanusHandler{ + "publish_remotely": func(room *TestJanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { + if value, found := api.GetStringMapString[string](body, "host"); assert.True(found) { + assert.Equal(hostname, value) + } + if value, found := api.GetStringMapEntry[float64](body, "port"); assert.True(found) { + assert.EqualValues(port, value) + } + if value, found := api.GetStringMapEntry[float64](body, "rtcp_port"); assert.True(found) { + assert.EqualValues(rtcpPort, value) + } + if value, found := api.GetStringMapString[string](body, "remote_id"); assert.True(found) { + prev := remotePublishId.Swap(value) + assert.Nil(prev, "should not have previous value") + } + + return &janus.SuccessMsg{ + Data: janus.SuccessData{ + ID: 1, + }, + }, nil + }, + "unpublish_remotely": func(room *TestJanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { + if value, found := api.GetStringMapString[string](body, "remote_id"); assert.True(found) { + if prev := remotePublishId.Load(); assert.NotNil(prev, "should have previous value") { + assert.Equal(prev, value) + } + } + return 
&janus.SuccessMsg{ + Data: janus.SuccessData{ + ID: 1, + }, + }, nil + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := PublicSessionId("publisher-id") + listener1 := &TestMcuListener{ + id: pubId, + } + + settings1 := NewPublisherSettings{} + initiator1 := &TestMcuInitiator{ + country: "DE", + } + + pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", StreamTypeVideo, settings1, initiator1) + require.NoError(err) + defer pub.Close(context.Background()) + + require.Implements((*McuRemoteAwarePublisher)(nil), pub) + remotePub, _ := pub.(McuRemoteAwarePublisher) + + if assert.NoError(remotePub.PublishRemote(ctx, remoteId, hostname, port, rtcpPort)) { + assert.NoError(remotePub.UnpublishRemote(ctx, remoteId, hostname, port, rtcpPort)) + } +} diff --git a/mcu_janus_remote_publisher.go b/mcu_janus_remote_publisher.go index f2bafe6..078adc5 100644 --- a/mcu_janus_remote_publisher.go +++ b/mcu_janus_remote_publisher.go @@ -63,38 +63,38 @@ func (p *mcuJanusRemotePublisher) handleEvent(event *janus.EventMsg) { ctx := context.TODO() switch videoroom { case "destroyed": - log.Printf("Remote publisher %d: associated room has been destroyed, closing", p.handleId) + log.Printf("Remote publisher %d: associated room has been destroyed, closing", p.handleId.Load()) go p.Close(ctx) case "slow_link": // Ignore, processed through "handleSlowLink" in the general events. 
default: - log.Printf("Unsupported videoroom remote publisher event in %d: %+v", p.handleId, event) + log.Printf("Unsupported videoroom remote publisher event in %d: %+v", p.handleId.Load(), event) } } else { - log.Printf("Unsupported remote publisher event in %d: %+v", p.handleId, event) + log.Printf("Unsupported remote publisher event in %d: %+v", p.handleId.Load(), event) } } func (p *mcuJanusRemotePublisher) handleHangup(event *janus.HangupMsg) { - log.Printf("Remote publisher %d received hangup (%s), closing", p.handleId, event.Reason) + log.Printf("Remote publisher %d received hangup (%s), closing", p.handleId.Load(), event.Reason) go p.Close(context.Background()) } func (p *mcuJanusRemotePublisher) handleDetached(event *janus.DetachedMsg) { - log.Printf("Remote publisher %d received detached, closing", p.handleId) + log.Printf("Remote publisher %d received detached, closing", p.handleId.Load()) go p.Close(context.Background()) } func (p *mcuJanusRemotePublisher) handleConnected(event *janus.WebRTCUpMsg) { - log.Printf("Remote publisher %d received connected", p.handleId) + log.Printf("Remote publisher %d received connected", p.handleId.Load()) p.mcu.publisherConnected.Notify(string(getStreamId(p.id, p.streamType))) } func (p *mcuJanusRemotePublisher) handleSlowLink(event *janus.SlowLinkMsg) { if event.Uplink { - log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost) + log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } else { - log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost) + log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } } @@ -107,12 +107,16 @@ func (p 
*mcuJanusRemotePublisher) NotifyReconnected() { return } - p.handle = handle - p.handleId = handle.Id + if prev := p.handle.Swap(handle); prev != nil { + if _, err := prev.Detach(context.Background()); err != nil { + log.Printf("Error detaching old remote publisher handle %d: %s", prev.Id, err) + } + } + p.handleId.Store(handle.Id) p.session = session p.roomId = roomId - log.Printf("Remote publisher %s reconnected on handle %d", p.id, p.handleId) + log.Printf("Remote publisher %s reconnected on handle %d", p.id, p.handleId.Load()) } func (p *mcuJanusRemotePublisher) Close(ctx context.Context) { @@ -125,8 +129,10 @@ func (p *mcuJanusRemotePublisher) Close(ctx context.Context) { } p.mu.Lock() - if handle := p.handle; handle != nil { - response, err := p.handle.Request(ctx, api.StringMap{ + defer p.mu.Unlock() + + if handle := p.handle.Load(); handle != nil { + response, err := handle.Request(ctx, api.StringMap{ "request": "remove_remote_publisher", "room": p.roomId, "id": streamTypeUserIds[p.streamType], @@ -154,5 +160,4 @@ func (p *mcuJanusRemotePublisher) Close(ctx context.Context) { } p.closeClient(ctx) - p.mu.Unlock() } diff --git a/mcu_janus_remote_subscriber.go b/mcu_janus_remote_subscriber.go index 8f7fe5e..356bb65 100644 --- a/mcu_janus_remote_subscriber.go +++ b/mcu_janus_remote_subscriber.go @@ -41,7 +41,7 @@ func (p *mcuJanusRemoteSubscriber) handleEvent(event *janus.EventMsg) { ctx := context.TODO() switch videoroom { case "destroyed": - log.Printf("Remote subscriber %d: associated room has been destroyed, closing", p.handleId) + log.Printf("Remote subscriber %d: associated room has been destroyed, closing", p.handleId.Load()) go p.Close(ctx) case "event": // Handle renegotiations, but ignore other events like selected @@ -53,33 +53,33 @@ func (p *mcuJanusRemoteSubscriber) handleEvent(event *janus.EventMsg) { case "slow_link": // Ignore, processed through "handleSlowLink" in the general events. 
default: - log.Printf("Unsupported videoroom event %s for remote subscriber %d: %+v", videoroom, p.handleId, event) + log.Printf("Unsupported videoroom event %s for remote subscriber %d: %+v", videoroom, p.handleId.Load(), event) } } else { - log.Printf("Unsupported event for remote subscriber %d: %+v", p.handleId, event) + log.Printf("Unsupported event for remote subscriber %d: %+v", p.handleId.Load(), event) } } func (p *mcuJanusRemoteSubscriber) handleHangup(event *janus.HangupMsg) { - log.Printf("Remote subscriber %d received hangup (%s), closing", p.handleId, event.Reason) + log.Printf("Remote subscriber %d received hangup (%s), closing", p.handleId.Load(), event.Reason) go p.Close(context.Background()) } func (p *mcuJanusRemoteSubscriber) handleDetached(event *janus.DetachedMsg) { - log.Printf("Remote subscriber %d received detached, closing", p.handleId) + log.Printf("Remote subscriber %d received detached, closing", p.handleId.Load()) go p.Close(context.Background()) } func (p *mcuJanusRemoteSubscriber) handleConnected(event *janus.WebRTCUpMsg) { - log.Printf("Remote subscriber %d received connected", p.handleId) + log.Printf("Remote subscriber %d received connected", p.handleId.Load()) p.mcu.SubscriberConnected(p.Id(), p.publisher, p.streamType) } func (p *mcuJanusRemoteSubscriber) handleSlowLink(event *janus.SlowLinkMsg) { if event.Uplink { - log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost) + log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } else { - log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost) + log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) 
} } @@ -98,12 +98,16 @@ func (p *mcuJanusRemoteSubscriber) NotifyReconnected() { return } - p.handle = handle - p.handleId = handle.Id + if prev := p.handle.Swap(handle); prev != nil { + if _, err := prev.Detach(context.Background()); err != nil { + log.Printf("Error detaching old remote subscriber handle %d: %s", prev.Id, err) + } + } + p.handleId.Store(handle.Id) p.roomId = pub.roomId p.sid = strconv.FormatUint(handle.Id, 10) p.listener.SubscriberSidUpdated(p) - log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId) + log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId.Load()) } func (p *mcuJanusRemoteSubscriber) Close(ctx context.Context) { diff --git a/mcu_janus_subscriber.go b/mcu_janus_subscriber.go index 9699cef..dfb21fe 100644 --- a/mcu_janus_subscriber.go +++ b/mcu_janus_subscriber.go @@ -47,7 +47,7 @@ func (p *mcuJanusSubscriber) handleEvent(event *janus.EventMsg) { ctx := context.TODO() switch videoroom { case "destroyed": - log.Printf("Subscriber %d: associated room has been destroyed, closing", p.handleId) + log.Printf("Subscriber %d: associated room has been destroyed, closing", p.handleId.Load()) go p.Close(ctx) case "updated": streams, ok := getPluginValue(event.Plugindata, pluginVideoRoom, "streams").([]any) @@ -64,45 +64,48 @@ func (p *mcuJanusSubscriber) handleEvent(event *janus.EventMsg) { } } - log.Printf("Subscriber %d: received updated event with no active media streams, closing", p.handleId) + log.Printf("Subscriber %d: received updated event with no active media streams, closing", p.handleId.Load()) go p.Close(ctx) case "event": // Handle renegotiations, but ignore other events like selected // substream / temporal layer. 
if getPluginStringValue(event.Plugindata, pluginVideoRoom, "configured") == "ok" && event.Jsep != nil && event.Jsep["type"] == "offer" && event.Jsep["sdp"] != nil { + log.Printf("Subscriber %d: received updated offer", p.handleId.Load()) p.listener.OnUpdateOffer(p, event.Jsep) + } else { + log.Printf("Subscriber %d: received unsupported event %+v", p.handleId.Load(), event) } case "slow_link": // Ignore, processed through "handleSlowLink" in the general events. default: - log.Printf("Unsupported videoroom event %s for subscriber %d: %+v", videoroom, p.handleId, event) + log.Printf("Unsupported videoroom event %s for subscriber %d: %+v", videoroom, p.handleId.Load(), event) } } else { - log.Printf("Unsupported event for subscriber %d: %+v", p.handleId, event) + log.Printf("Unsupported event for subscriber %d: %+v", p.handleId.Load(), event) } } func (p *mcuJanusSubscriber) handleHangup(event *janus.HangupMsg) { - log.Printf("Subscriber %d received hangup (%s), closing", p.handleId, event.Reason) + log.Printf("Subscriber %d received hangup (%s), closing", p.handleId.Load(), event.Reason) go p.Close(context.Background()) } func (p *mcuJanusSubscriber) handleDetached(event *janus.DetachedMsg) { - log.Printf("Subscriber %d received detached, closing", p.handleId) + log.Printf("Subscriber %d received detached, closing", p.handleId.Load()) go p.Close(context.Background()) } func (p *mcuJanusSubscriber) handleConnected(event *janus.WebRTCUpMsg) { - log.Printf("Subscriber %d received connected", p.handleId) + log.Printf("Subscriber %d received connected", p.handleId.Load()) p.mcu.SubscriberConnected(p.Id(), p.publisher, p.streamType) } func (p *mcuJanusSubscriber) handleSlowLink(event *janus.SlowLinkMsg) { if event.Uplink { - log.Printf("Subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost) + log.Printf("Subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", 
p.listener.PublicId(), p.handleId.Load(), event.Lost) } else { - log.Printf("Subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost) + log.Printf("Subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost) } } @@ -121,12 +124,16 @@ func (p *mcuJanusSubscriber) NotifyReconnected() { return } - p.handle = handle - p.handleId = handle.Id + if prev := p.handle.Swap(handle); prev != nil { + if _, err := prev.Detach(context.Background()); err != nil { + log.Printf("Error detaching old subscriber handle %d: %s", prev.Id, err) + } + } + p.handleId.Store(handle.Id) p.roomId = pub.roomId p.sid = strconv.FormatUint(handle.Id, 10) p.listener.SubscriberSidUpdated(p) - log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId) + log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId.Load()) } func (p *mcuJanusSubscriber) closeClient(ctx context.Context) bool { @@ -152,7 +159,7 @@ func (p *mcuJanusSubscriber) Close(ctx context.Context) { } func (p *mcuJanusSubscriber) joinRoom(ctx context.Context, stream *streamSelection, callback func(error, api.StringMap)) { - handle := p.handle + handle := p.handle.Load() if handle == nil { callback(ErrNotConnected, nil) return @@ -210,15 +217,19 @@ retry: return } - p.handle = handle - p.handleId = handle.Id + if prev := p.handle.Swap(handle); prev != nil { + if _, err := prev.Detach(context.Background()); err != nil { + log.Printf("Error detaching old subscriber handle %d: %s", prev.Id, err) + } + } + p.handleId.Store(handle.Id) p.roomId = pub.roomId p.sid = strconv.FormatUint(handle.Id, 10) p.listener.SubscriberSidUpdated(p) p.closeChan = make(chan struct{}, 1) statsSubscribersCurrent.WithLabelValues(string(p.streamType)).Inc() - go p.run(p.handle, p.closeChan) - log.Printf("Already connected subscriber 
%d for %s, leaving and re-joining on handle %d", p.id, p.streamType, p.handleId) + go p.run(handle, p.closeChan) + log.Printf("Already connected subscriber %d for %s, leaving and re-joining on handle %d", p.id, p.streamType, p.handleId.Load()) goto retry case JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM: fallthrough @@ -258,7 +269,7 @@ retry: } func (p *mcuJanusSubscriber) update(ctx context.Context, stream *streamSelection, callback func(error, api.StringMap)) { - handle := p.handle + handle := p.handle.Load() if handle == nil { callback(ErrNotConnected, nil) return diff --git a/mcu_janus_test.go b/mcu_janus_test.go index 1bfa955..641a0f5 100644 --- a/mcu_janus_test.go +++ b/mcu_janus_test.go @@ -58,19 +58,27 @@ type TestJanusGateway struct { sid atomic.Uint64 tid atomic.Uint64 - hid atomic.Uint64 - rid atomic.Uint64 + hid atomic.Uint64 // +checklocksignore: Atomic + rid atomic.Uint64 // +checklocksignore: Atomic mu sync.Mutex - sessions map[uint64]*JanusSession + // +checklocks:mu + sessions map[uint64]*JanusSession + // +checklocks:mu transactions map[uint64]*transaction - handles map[uint64]*TestJanusHandle - rooms map[uint64]*TestJanusRoom - handlers map[string]TestJanusHandler + // +checklocks:mu + handles map[uint64]*TestJanusHandle + // +checklocks:mu + rooms map[uint64]*TestJanusRoom + // +checklocks:mu + handlers map[string]TestJanusHandler - attachCount atomic.Int32 - joinCount atomic.Int32 + // +checklocks:mu + attachCount int + // +checklocks:mu + joinCount int + // +checklocks:mu handleRooms map[*TestJanusHandle]*TestJanusRoom } @@ -142,6 +150,19 @@ func (g *TestJanusGateway) Close() error { return nil } +func (g *TestJanusGateway) simulateEvent(delay time.Duration, session *JanusSession, handle *TestJanusHandle, event any) { + go func() { + time.Sleep(delay) + session.Lock() + h, found := session.Handles[handle.id] + session.Unlock() + if found { + h.Events <- event + } + }() +} + +// +checklocks:g.mu func (g *TestJanusGateway) processMessage(session 
*JanusSession, handle *TestJanusHandle, body api.StringMap, jsep api.StringMap) any { request := body["request"].(string) switch request { @@ -165,15 +186,18 @@ func (g *TestJanusGateway) processMessage(session *JanusSession, handle *TestJan error_code := JANUS_OK if body["ptype"] == "subscriber" { if strings.Contains(g.t.Name(), "NoSuchRoom") { - if g.joinCount.Add(1) == 1 { + g.joinCount++ + if g.joinCount == 1 { error_code = JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM } } else if strings.Contains(g.t.Name(), "AlreadyJoined") { - if g.joinCount.Add(1) == 1 { + g.joinCount++ + if g.joinCount == 1 { error_code = JANUS_VIDEOROOM_ERROR_ALREADY_JOINED } } else if strings.Contains(g.t.Name(), "SubscriberTimeout") { - if g.joinCount.Add(1) == 1 { + g.joinCount++ + if g.joinCount == 1 { error_code = JANUS_VIDEOROOM_ERROR_NO_SUCH_FEED } } @@ -236,6 +260,66 @@ func (g *TestJanusGateway) processMessage(session *JanusSession, handle *TestJan } sdp := publisher.sdp.Load() + + // Simulate "connected" event for subscriber. + g.simulateEvent(15*time.Millisecond, session, handle, &janus.WebRTCUpMsg{ + Session: session.Id, + Handle: handle.id, + }) + + if strings.Contains(g.t.Name(), "CloseEmptyStreams") { + // Simulate stream update event with no active streams. + g.simulateEvent(20*time.Millisecond, session, handle, &janus.EventMsg{ + Session: session.Id, + Handle: handle.id, + Plugindata: janus.PluginData{ + Plugin: pluginVideoRoom, + Data: api.StringMap{ + "videoroom": "updated", + "streams": []any{ + api.StringMap{ + "type": "audio", + "active": false, + }, + }, + }, + }, + }) + } + + if strings.Contains(g.t.Name(), "SubscriberRoomDestroyed") { + // Simulate event that subscriber room has been destroyed. 
+ g.simulateEvent(20*time.Millisecond, session, handle, &janus.EventMsg{ + Session: session.Id, + Handle: handle.id, + Plugindata: janus.PluginData{ + Plugin: pluginVideoRoom, + Data: api.StringMap{ + "videoroom": "destroyed", + }, + }, + }) + } + + if strings.Contains(g.t.Name(), "SubscriberUpdateOffer") { + // Simulate event that subscriber receives new offer. + g.simulateEvent(20*time.Millisecond, session, handle, &janus.EventMsg{ + Session: session.Id, + Handle: handle.id, + Plugindata: janus.PluginData{ + Plugin: pluginVideoRoom, + Data: api.StringMap{ + "videoroom": "event", + "configured": "ok", + }, + }, + Jsep: map[string]any{ + "type": "offer", + "sdp": MockSdpOfferAudioOnly, + }, + }) + } + return &janus.EventMsg{ Jsep: api.StringMap{ "type": "offer", @@ -305,23 +389,13 @@ func (g *TestJanusGateway) processMessage(session *JanusSession, handle *TestJan case "configure": if sdp, found := jsep["sdp"]; found { handle.sdp.Store(sdp.(string)) - // Simulate "connected" event. - go func() { - if strings.Contains(g.t.Name(), "SubscriberTimeout") { - return - } - - time.Sleep(10 * time.Millisecond) - session.Lock() - h, found := session.Handles[handle.id] - session.Unlock() - if found { - h.Events <- &janus.WebRTCUpMsg{ - Session: session.Id, - Handle: h.Id, - } - } - }() + if !strings.Contains(g.t.Name(), "SubscriberTimeout") { + // Simulate "connected" event for publisher. 
+ g.simulateEvent(10*time.Millisecond, session, handle, &janus.WebRTCUpMsg{ + Session: session.Id, + Handle: handle.id, + }) + } } } } @@ -354,7 +428,8 @@ func (g *TestJanusGateway) processRequest(msg api.StringMap) any { switch method { case "attach": if strings.Contains(g.t.Name(), "AlreadyJoinedAttachError") { - if g.attachCount.Add(1) == 4 { + g.attachCount++ + if g.attachCount == 4 { return &janus.ErrorMsg{ Err: janus.ErrorData{ Code: JANUS_ERROR_UNKNOWN, @@ -1581,3 +1656,318 @@ func Test_JanusSubscriberTimeout(t *testing.T) { client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo) } + +func Test_JanusSubscriberCloseEmptyStreams(t *testing.T) { + ResetStatsValue(t, statsSubscribersCurrent.WithLabelValues("video")) + t.Cleanup(func() { + if !t.Failed() { + checkStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"), 0) + } + }) + + CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) + + mcu, gateway := newMcuJanusForTesting(t) + gateway.registerHandlers(map[string]TestJanusHandler{ + "configure": func(room *TestJanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { + assert.EqualValues(1, room.id) + return &janus.EventMsg{ + Jsep: api.StringMap{ + "type": "answer", + "sdp": MockSdpAnswerAudioAndVideo, + }, + }, nil + }, + }) + + hub, _, _, server := CreateHubForTest(t) + hub.SetMcu(mcu) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") + client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2") + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) + require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) + + // Join room by id. + roomId := "test-room" + roomMsg := MustSucceed2(t, client1.JoinRoom, ctx, roomId) + require.Equal(roomId, roomMsg.Room.RoomId) + + // Give message processing some time. 
+ time.Sleep(10 * time.Millisecond) + + roomMsg = MustSucceed2(t, client2.JoinRoom, ctx, roomId) + require.Equal(roomId, roomMsg.Room.RoomId) + + WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) + + // Simulate request from the backend that sessions joined the call. + users1 := []api.StringMap{ + { + "sessionId": hello1.Hello.SessionId, + "inCall": 1, + }, + { + "sessionId": hello2.Hello.SessionId, + "inCall": 1, + }, + } + room := hub.getRoom(roomId) + require.NotNil(room, "Could not find room %s", roomId) + room.PublishUsersInCallChanged(users1, users1) + checkReceiveClientEvent(ctx, t, client1, "update", nil) + checkReceiveClientEvent(ctx, t, client2, "update", nil) + + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: api.StringMap{ + "sdp": MockSdpOfferAudioAndVideo, + }, + })) + + client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo) + + require.NoError(client2.SendMessage(MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, MessageClientMessageData{ + Type: "requestoffer", + RoomType: "video", + })) + + client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo) + + sess2 := hub.GetSessionByPublicId(hello2.Hello.SessionId) + require.NotNil(sess2) + session2 := sess2.(*ClientSession) + + sub := session2.GetSubscriber(hello1.Hello.SessionId, StreamTypeVideo) + require.NotNil(sub) + + subscriber := sub.(*mcuJanusSubscriber) + handle := subscriber.handle.Load() + require.NotNil(handle) + + for ctx.Err() == nil { + if handle = subscriber.handle.Load(); handle == nil { + break + } + + time.Sleep(time.Millisecond) + } + + assert.Nil(handle, "subscriber should have been closed") +} + +func Test_JanusSubscriberRoomDestroyed(t *testing.T) { + ResetStatsValue(t, statsSubscribersCurrent.WithLabelValues("video")) + t.Cleanup(func() { + if !t.Failed() { + checkStatsValue(t, 
statsSubscribersCurrent.WithLabelValues("video"), 0) + } + }) + + CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) + + mcu, gateway := newMcuJanusForTesting(t) + gateway.registerHandlers(map[string]TestJanusHandler{ + "configure": func(room *TestJanusRoom, body, jsep api.StringMap) (any, *janus.ErrorMsg) { + assert.EqualValues(1, room.id) + return &janus.EventMsg{ + Jsep: api.StringMap{ + "type": "answer", + "sdp": MockSdpAnswerAudioAndVideo, + }, + }, nil + }, + }) + + hub, _, _, server := CreateHubForTest(t) + hub.SetMcu(mcu) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") + client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2") + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) + require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) + + // Join room by id. + roomId := "test-room" + roomMsg := MustSucceed2(t, client1.JoinRoom, ctx, roomId) + require.Equal(roomId, roomMsg.Room.RoomId) + + // Give message processing some time. + time.Sleep(10 * time.Millisecond) + + roomMsg = MustSucceed2(t, client2.JoinRoom, ctx, roomId) + require.Equal(roomId, roomMsg.Room.RoomId) + + WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) + + // Simulate request from the backend that sessions joined the call. 
+ users1 := []api.StringMap{ + { + "sessionId": hello1.Hello.SessionId, + "inCall": 1, + }, + { + "sessionId": hello2.Hello.SessionId, + "inCall": 1, + }, + } + room := hub.getRoom(roomId) + require.NotNil(room, "Could not find room %s", roomId) + room.PublishUsersInCallChanged(users1, users1) + checkReceiveClientEvent(ctx, t, client1, "update", nil) + checkReceiveClientEvent(ctx, t, client2, "update", nil) + + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: api.StringMap{ + "sdp": MockSdpOfferAudioAndVideo, + }, + })) + + client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo) + + require.NoError(client2.SendMessage(MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, MessageClientMessageData{ + Type: "requestoffer", + RoomType: "video", + })) + + client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo) + + sess2 := hub.GetSessionByPublicId(hello2.Hello.SessionId) + require.NotNil(sess2) + session2 := sess2.(*ClientSession) + + sub := session2.GetSubscriber(hello1.Hello.SessionId, StreamTypeVideo) + require.NotNil(sub) + + subscriber := sub.(*mcuJanusSubscriber) + handle := subscriber.handle.Load() + require.NotNil(handle) + + for ctx.Err() == nil { + if handle = subscriber.handle.Load(); handle == nil { + break + } + + time.Sleep(time.Millisecond) + } + + assert.Nil(handle, "subscriber should have been closed") +} + +func Test_JanusSubscriberUpdateOffer(t *testing.T) { + ResetStatsValue(t, statsSubscribersCurrent.WithLabelValues("video")) + t.Cleanup(func() { + if !t.Failed() { + checkStatsValue(t, statsSubscribersCurrent.WithLabelValues("video"), 0) + } + }) + + CatchLogForTest(t) + require := require.New(t) + assert := assert.New(t) + + mcu, gateway := newMcuJanusForTesting(t) + gateway.registerHandlers(map[string]TestJanusHandler{ + "configure": func(room *TestJanusRoom, 
body, jsep api.StringMap) (any, *janus.ErrorMsg) { + assert.EqualValues(1, room.id) + return &janus.EventMsg{ + Jsep: api.StringMap{ + "type": "answer", + "sdp": MockSdpAnswerAudioAndVideo, + }, + }, nil + }, + }) + + hub, _, _, server := CreateHubForTest(t) + hub.SetMcu(mcu) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1, hello1 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"1") + client2, hello2 := NewTestClientWithHello(ctx, t, server, hub, testDefaultUserId+"2") + require.NotEqual(hello1.Hello.SessionId, hello2.Hello.SessionId) + require.NotEqual(hello1.Hello.UserId, hello2.Hello.UserId) + + // Join room by id. + roomId := "test-room" + roomMsg := MustSucceed2(t, client1.JoinRoom, ctx, roomId) + require.Equal(roomId, roomMsg.Room.RoomId) + + // Give message processing some time. + time.Sleep(10 * time.Millisecond) + + roomMsg = MustSucceed2(t, client2.JoinRoom, ctx, roomId) + require.Equal(roomId, roomMsg.Room.RoomId) + + WaitForUsersJoined(ctx, t, client1, hello1, client2, hello2) + + // Simulate request from the backend that sessions joined the call. 
+ users1 := []api.StringMap{ + { + "sessionId": hello1.Hello.SessionId, + "inCall": 1, + }, + { + "sessionId": hello2.Hello.SessionId, + "inCall": 1, + }, + } + room := hub.getRoom(roomId) + require.NotNil(room, "Could not find room %s", roomId) + room.PublishUsersInCallChanged(users1, users1) + checkReceiveClientEvent(ctx, t, client1, "update", nil) + checkReceiveClientEvent(ctx, t, client2, "update", nil) + + require.NoError(client1.SendMessage(MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, MessageClientMessageData{ + Type: "offer", + RoomType: "video", + Payload: api.StringMap{ + "sdp": MockSdpOfferAudioAndVideo, + }, + })) + + client1.RunUntilAnswer(ctx, MockSdpAnswerAudioAndVideo) + + require.NoError(client2.SendMessage(MessageClientMessageRecipient{ + Type: "session", + SessionId: hello1.Hello.SessionId, + }, MessageClientMessageData{ + Type: "requestoffer", + RoomType: "video", + })) + + client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo) + + // Test MCU will trigger an updated offer. 
+ client2.RunUntilOffer(ctx, MockSdpOfferAudioOnly) +} diff --git a/mcu_proxy.go b/mcu_proxy.go index 5bcf45c..167a2af 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -351,10 +351,11 @@ type mcuProxyConnection struct { closer *Closer closedDone *Closer closed atomic.Bool - conn *websocket.Conn + // +checklocks:mu + conn *websocket.Conn helloProcessed atomic.Bool - connectedSince time.Time + connectedSince atomic.Int64 reconnectTimer *time.Timer reconnectInterval atomic.Int64 shutdownScheduled atomic.Bool @@ -371,15 +372,20 @@ type mcuProxyConnection struct { version atomic.Value features atomic.Value - callbacks map[string]mcuProxyCallback + // +checklocks:mu + callbacks map[string]mcuProxyCallback + // +checklocks:mu deferredCallbacks map[string]mcuProxyCallback publishersLock sync.RWMutex - publishers map[string]*mcuProxyPublisher - publisherIds map[StreamId]PublicSessionId + // +checklocks:publishersLock + publishers map[string]*mcuProxyPublisher + // +checklocks:publishersLock + publisherIds map[StreamId]PublicSessionId subscribersLock sync.RWMutex - subscribers map[string]*mcuProxySubscriber + // +checklocks:subscribersLock + subscribers map[string]*mcuProxySubscriber } func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP, token string) (*mcuProxyConnection, error) { @@ -486,7 +492,10 @@ func (c *mcuProxyConnection) GetStats() *mcuProxyConnectionStats { c.mu.Lock() if c.conn != nil { result.Connected = true - result.Uptime = &c.connectedSince + if since := c.connectedSince.Load(); since != 0 { + t := time.UnixMicro(since) + result.Uptime = &t + } load := c.Load() result.Load = &load shutdown := c.IsShutdownScheduled() @@ -705,6 +714,7 @@ func (c *mcuProxyConnection) close() { if c.conn != nil { c.conn.Close() c.conn = nil + c.connectedSince.Store(0) if c.trackClose.CompareAndSwap(true, false) { statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Dec() } @@ -810,9 +820,9 @@ func (c *mcuProxyConnection) reconnect() { 
log.Printf("Connected to %s", c) c.closed.Store(false) c.helloProcessed.Store(false) + c.connectedSince.Store(time.Now().UnixMicro()) c.mu.Lock() - c.connectedSince = time.Now() c.conn = conn c.mu.Unlock() @@ -1157,6 +1167,7 @@ func (c *mcuProxyConnection) sendMessage(msg *ProxyClientMessage) error { return c.sendMessageLocked(msg) } +// +checklocks:c.mu func (c *mcuProxyConnection) sendMessageLocked(msg *ProxyClientMessage) error { if proxyDebugMessages { log.Printf("Send message to %s: %+v", c, msg) @@ -1449,16 +1460,19 @@ type mcuProxy struct { tokenKey *rsa.PrivateKey config ProxyConfig - dialer *websocket.Dialer - connections []*mcuProxyConnection + dialer *websocket.Dialer + connectionsMu sync.RWMutex + // +checklocks:connectionsMu + connections []*mcuProxyConnection + // +checklocks:connectionsMu connectionsMap map[string][]*mcuProxyConnection - connectionsMu sync.RWMutex connRequests atomic.Int64 nextSort atomic.Int64 settings McuSettings - mu sync.RWMutex + mu sync.RWMutex + // +checklocks:mu publishers map[StreamId]*mcuProxyConnection publisherWaiters ChannelWaiters @@ -1615,6 +1629,13 @@ func (m *mcuProxy) createToken(subject string) (string, error) { return tokenString, nil } +func (m *mcuProxy) getConnections() []*mcuProxyConnection { + m.connectionsMu.RLock() + defer m.connectionsMu.RUnlock() + + return m.connections +} + func (m *mcuProxy) hasConnections() bool { m.connectionsMu.RLock() defer m.connectionsMu.RUnlock() @@ -1835,7 +1856,10 @@ func (m *mcuProxy) GetServerInfoSfu() *BackendServerInfoSfu { if c.IsConnected() { proxy.Connected = true proxy.Shutdown = internal.MakePtr(c.IsShutdownScheduled()) - proxy.Uptime = &c.connectedSince + if since := c.connectedSince.Load(); since != 0 { + t := time.UnixMicro(since) + proxy.Uptime = &t + } proxy.Version = c.Version() proxy.Features = c.Features() proxy.Country = c.Country() diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index efd87d0..5f3e64e 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go 
@@ -204,7 +204,8 @@ type testProxyServerSubscriber struct { type testProxyServerClient struct { t *testing.T - server *TestProxyServerHandler + server *TestProxyServerHandler + // +checklocks:mu ws *websocket.Conn processMessage proxyServerClientHandler @@ -549,12 +550,15 @@ type TestProxyServerHandler struct { upgrader *websocket.Upgrader country string - mu sync.Mutex - load atomic.Int64 - incoming atomic.Pointer[float64] - outgoing atomic.Pointer[float64] - clients map[PublicSessionId]*testProxyServerClient - publishers map[PublicSessionId]*testProxyServerPublisher + mu sync.Mutex + load atomic.Int64 + incoming atomic.Pointer[float64] + outgoing atomic.Pointer[float64] + // +checklocks:mu + clients map[PublicSessionId]*testProxyServerClient + // +checklocks:mu + publishers map[PublicSessionId]*testProxyServerPublisher + // +checklocks:mu subscribers map[string]*testProxyServerSubscriber wakeupChan chan struct{} @@ -1039,8 +1043,8 @@ func Test_ProxyAddRemoveConnectionsDnsDiscovery(t *testing.T) { }, }, 0) - if assert.NotNil(mcu.connections[0].ip) { - assert.True(ip1.Equal(mcu.connections[0].ip), "ip addresses differ: expected %s, got %s", ip1.String(), mcu.connections[0].ip.String()) + if connections := mcu.getConnections(); assert.Len(connections, 1) && assert.NotNil(connections[0].ip) { + assert.True(ip1.Equal(connections[0].ip), "ip addresses differ: expected %s, got %s", ip1.String(), connections[0].ip.String()) } dnsMonitor := mcu.config.(*proxyConfigStatic).dnsMonitor @@ -1744,8 +1748,9 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) { } type mockGrpcServerHub struct { - proxy atomic.Pointer[mcuProxy] - sessionsLock sync.Mutex + proxy atomic.Pointer[mcuProxy] + sessionsLock sync.Mutex + // +checklocks:sessionsLock sessionByPublicId map[PublicSessionId]Session } diff --git a/mcu_test.go b/mcu_test.go index 62f6d2a..9f26eb5 100644 --- a/mcu_test.go +++ b/mcu_test.go @@ -41,8 +41,10 @@ const ( ) type TestMCU struct { - mu sync.Mutex - publishers 
map[PublicSessionId]*TestMCUPublisher + mu sync.Mutex + // +checklocks:mu + publishers map[PublicSessionId]*TestMCUPublisher + // +checklocks:mu subscribers map[string]*TestMCUSubscriber } diff --git a/natsclient_loopback.go b/natsclient_loopback.go index 952f27a..6b8f56a 100644 --- a/natsclient_loopback.go +++ b/natsclient_loopback.go @@ -32,10 +32,13 @@ import ( ) type LoopbackNatsClient struct { - mu sync.Mutex + mu sync.Mutex + // +checklocks:mu subscriptions map[string]map[*loopbackNatsSubscription]bool - wakeup sync.Cond + // +checklocks:mu + wakeup sync.Cond + // +checklocks:mu incoming list.List } @@ -65,6 +68,7 @@ func (c *LoopbackNatsClient) processMessages() { } } +// +checklocks:c.mu func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) { subs, found := c.subscriptions[msg.Subject] if !found { diff --git a/notifier.go b/notifier.go index 3466f45..800711d 100644 --- a/notifier.go +++ b/notifier.go @@ -39,7 +39,9 @@ func (w *Waiter) Wait(ctx context.Context) error { type Notifier struct { sync.Mutex - waiters map[string]*Waiter + // +checklocks:Mutex + waiters map[string]*Waiter + // +checklocks:Mutex waiterMap map[string]map[*Waiter]bool } diff --git a/proxy/proxy_remote.go b/proxy/proxy_remote.go index b86dd73..6eaad05 100644 --- a/proxy/proxy_remote.go +++ b/proxy/proxy_remote.go @@ -58,30 +58,38 @@ const ( ) var ( - ErrNotConnected = errors.New("not connected") + ErrNotConnected = errors.New("not connected") // +checklocksignore: Global readonly variable. ) type RemoteConnection struct { - mu sync.Mutex - p *ProxyServer - url *url.URL - conn *websocket.Conn - closer *signaling.Closer - closed atomic.Bool + mu sync.Mutex + p *ProxyServer + url *url.URL + // +checklocks:mu + conn *websocket.Conn + closeCtx context.Context + closeFunc context.CancelFunc // +checklocksignore: Only written to from constructor. 
tokenId string tokenKey *rsa.PrivateKey tlsConfig *tls.Config + // +checklocks:mu connectedSince time.Time reconnectTimer *time.Timer reconnectInterval atomic.Int64 - msgId atomic.Int64 + msgId atomic.Int64 + // +checklocks:mu helloMsgId string - sessionId signaling.PublicSessionId + // +checklocks:mu + sessionId signaling.PublicSessionId + // +checklocks:mu + helloReceived bool - pendingMessages []*signaling.ProxyClientMessage + // +checklocks:mu + pendingMessages []*signaling.ProxyClientMessage + // +checklocks:mu messageCallbacks map[string]chan *signaling.ProxyServerMessage } @@ -91,10 +99,13 @@ func NewRemoteConnection(p *ProxyServer, proxyUrl string, tokenId string, tokenK return nil, err } + closeCtx, closeFunc := context.WithCancel(context.Background()) + result := &RemoteConnection{ - p: p, - url: u, - closer: signaling.NewCloser(), + p: p, + url: u, + closeCtx: closeCtx, + closeFunc: closeFunc, tokenId: tokenId, tokenKey: tokenKey, @@ -115,6 +126,12 @@ func (c *RemoteConnection) String() string { return c.url.String() } +func (c *RemoteConnection) SessionId() signaling.PublicSessionId { + c.mu.Lock() + defer c.mu.Unlock() + return c.sessionId +} + func (c *RemoteConnection) reconnect() { u, err := c.url.Parse("proxy") if err != nil { @@ -142,7 +159,6 @@ func (c *RemoteConnection) reconnect() { } log.Printf("Connected to %s", c) - c.closed.Store(false) c.mu.Lock() c.connectedSince = time.Now() @@ -151,24 +167,39 @@ func (c *RemoteConnection) reconnect() { c.reconnectInterval.Store(int64(initialReconnectInterval)) - if err := c.sendHello(); err != nil { - log.Printf("Error sending hello request to proxy at %s: %s", c, err) + if !c.sendReconnectHello() || !c.sendPing() { c.scheduleReconnect() return } - if !c.sendPing() { - return - } - go c.readPump(conn) } +func (c *RemoteConnection) sendReconnectHello() bool { + c.mu.Lock() + defer c.mu.Unlock() + + if err := c.sendHello(c.closeCtx); err != nil { + log.Printf("Error sending hello request to proxy at %s: 
%s", c, err) + return false + } + + return true +} + func (c *RemoteConnection) scheduleReconnect() { - if err := c.sendClose(); err != nil && err != ErrNotConnected { + c.mu.Lock() + defer c.mu.Unlock() + + c.scheduleReconnectLocked() +} + +// +checklocks:c.mu +func (c *RemoteConnection) scheduleReconnectLocked() { + if err := c.sendCloseLocked(); err != nil && err != ErrNotConnected { log.Printf("Could not send close message to %s: %s", c, err) } - c.close() + c.closeLocked() interval := c.reconnectInterval.Load() // Prevent all servers from reconnecting at the same time in case of an @@ -180,7 +211,8 @@ func (c *RemoteConnection) scheduleReconnect() { c.reconnectInterval.Store(interval) } -func (c *RemoteConnection) sendHello() error { +// +checklocks:c.mu +func (c *RemoteConnection) sendHello(ctx context.Context) error { c.helloMsgId = strconv.FormatInt(c.msgId.Add(1), 10) msg := &signaling.ProxyClientMessage{ Id: c.helloMsgId, @@ -200,16 +232,10 @@ func (c *RemoteConnection) sendHello() error { msg.Hello.Token = tokenString } - return c.SendMessage(msg) -} - -func (c *RemoteConnection) sendClose() error { - c.mu.Lock() - defer c.mu.Unlock() - - return c.sendCloseLocked() + return c.sendMessageLocked(ctx, msg) } +// +checklocks:c.mu func (c *RemoteConnection) sendCloseLocked() error { if c.conn == nil { return ErrNotConnected @@ -223,10 +249,17 @@ func (c *RemoteConnection) close() { c.mu.Lock() defer c.mu.Unlock() + c.closeLocked() +} + +// +checklocks:c.mu +func (c *RemoteConnection) closeLocked() { if c.conn != nil { c.conn.Close() c.conn = nil } + c.connectedSince = time.Time{} + c.helloReceived = false } func (c *RemoteConnection) Close() error { @@ -237,15 +270,17 @@ func (c *RemoteConnection) Close() error { return nil } - if !c.closed.CompareAndSwap(false, true) { + if c.closeCtx.Err() != nil { // Already closed return nil } - c.closer.Close() + c.closeFunc() err1 := c.conn.WriteControl(websocket.CloseMessage, 
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}) err2 := c.conn.Close() c.conn = nil + c.connectedSince = time.Time{} + c.helloReceived = false if err1 != nil { return err1 } @@ -273,9 +308,10 @@ func (c *RemoteConnection) SendMessage(msg *signaling.ProxyClientMessage) error c.mu.Lock() defer c.mu.Unlock() - return c.sendMessageLocked(context.Background(), msg) + return c.sendMessageLocked(c.closeCtx, msg) } +// +checklocks:c.mu func (c *RemoteConnection) deferMessage(ctx context.Context, msg *signaling.ProxyClientMessage) { c.pendingMessages = append(c.pendingMessages, msg) if ctx.Done() != nil { @@ -294,6 +330,7 @@ func (c *RemoteConnection) deferMessage(ctx context.Context, msg *signaling.Prox } } +// +checklocks:c.mu func (c *RemoteConnection) sendMessageLocked(ctx context.Context, msg *signaling.ProxyClientMessage) error { if c.conn == nil { // Defer until connected. @@ -313,7 +350,7 @@ func (c *RemoteConnection) sendMessageLocked(ctx context.Context, msg *signaling func (c *RemoteConnection) readPump(conn *websocket.Conn) { defer func() { - if !c.closed.Load() { + if c.closeCtx.Err() == nil { c.scheduleReconnect() } }() @@ -328,7 +365,7 @@ func (c *RemoteConnection) readPump(conn *websocket.Conn) { websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - if !errors.Is(err, net.ErrClosed) || !c.closed.Load() { + if !errors.Is(err, net.ErrClosed) || c.closeCtx.Err() == nil { log.Printf("Error reading from %s: %v", c, err) } } @@ -390,31 +427,35 @@ func (c *RemoteConnection) writePump() { c.reconnect() case <-ticker.C: c.sendPing() - case <-c.closer.C: + case <-c.closeCtx.Done(): return } } } func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { + c.mu.Lock() + defer c.mu.Unlock() + c.helloMsgId = "" switch msg.Type { case "error": if msg.Error.Code == "no_such_session" { log.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c) c.sessionId = "" - if 
err := c.sendHello(); err != nil { + if err := c.sendHello(c.closeCtx); err != nil { log.Printf("Could not send hello request to %s: %s", c, err) - c.scheduleReconnect() + c.scheduleReconnectLocked() } return } log.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error) - c.scheduleReconnect() + c.scheduleReconnectLocked() case "hello": resumed := c.sessionId == msg.Hello.SessionId c.sessionId = msg.Hello.SessionId + c.helloReceived = true country := "" if msg.Hello.Server != nil { if country = msg.Hello.Server.Country; country != "" && !signaling.IsValidCountry(country) { @@ -437,27 +478,38 @@ func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) { continue } - if err := c.sendMessageLocked(context.Background(), m); err != nil { + if err := c.sendMessageLocked(c.closeCtx, m); err != nil { log.Printf("Could not send pending message %+v to %s: %s", m, c, err) } } default: log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c) - c.scheduleReconnect() + c.scheduleReconnectLocked() } } -func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) { - if msg.Id != "" { - c.mu.Lock() - ch, found := c.messageCallbacks[msg.Id] - if found { - delete(c.messageCallbacks, msg.Id) - c.mu.Unlock() - ch <- msg - return - } +func (c *RemoteConnection) handleCallback(msg *signaling.ProxyServerMessage) bool { + if msg.Id == "" { + return false + } + + c.mu.Lock() + ch, found := c.messageCallbacks[msg.Id] + if !found { c.mu.Unlock() + return false + } + + delete(c.messageCallbacks, msg.Id) + c.mu.Unlock() + + ch <- msg + return true +} + +func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) { + if c.handleCallback(msg) { + return } switch msg.Type { @@ -467,7 +519,9 @@ func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) { log.Printf("Connection to %s was closed: %s", c, msg.Bye.Reason) if msg.Bye.Reason == "session_expired" { // Don't try to resume 
expired session. + c.mu.Lock() c.sessionId = "" + c.mu.Unlock() } c.scheduleReconnect() default: @@ -487,21 +541,31 @@ func (c *RemoteConnection) processEvent(msg *signaling.ProxyServerMessage) { } } -func (c *RemoteConnection) RequestMessage(ctx context.Context, msg *signaling.ProxyClientMessage) (*signaling.ProxyServerMessage, error) { +func (c *RemoteConnection) sendMessageWithCallbackLocked(ctx context.Context, msg *signaling.ProxyClientMessage) (string, <-chan *signaling.ProxyServerMessage, error) { msg.Id = strconv.FormatInt(c.msgId.Add(1), 10) c.mu.Lock() defer c.mu.Unlock() - if err := c.sendMessageLocked(ctx, msg); err != nil { - return nil, err + msg.Id = "" + return "", nil, err } + ch := make(chan *signaling.ProxyServerMessage, 1) c.messageCallbacks[msg.Id] = ch - c.mu.Unlock() + return msg.Id, ch, nil +} + +func (c *RemoteConnection) RequestMessage(ctx context.Context, msg *signaling.ProxyClientMessage) (*signaling.ProxyServerMessage, error) { + id, ch, err := c.sendMessageWithCallbackLocked(ctx, msg) + if err != nil { + return nil, err + } + defer func() { c.mu.Lock() - delete(c.messageCallbacks, msg.Id) + defer c.mu.Unlock() + delete(c.messageCallbacks, id) }() select { @@ -515,3 +579,15 @@ func (c *RemoteConnection) RequestMessage(ctx context.Context, msg *signaling.Pr return response, nil } } + +func (c *RemoteConnection) SendBye() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.conn == nil { + return nil + } + + return c.sendMessageLocked(c.closeCtx, &signaling.ProxyClientMessage{ + Type: "bye", + }) +} diff --git a/proxy/proxy_remote_test.go b/proxy/proxy_remote_test.go new file mode 100644 index 0000000..75515f1 --- /dev/null +++ b/proxy/proxy_remote_test.go @@ -0,0 +1,212 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. 
+ * Copyright (C) 2025 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package main + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + signaling "github.com/strukturag/nextcloud-spreed-signaling" +) + +func (c *RemoteConnection) WaitForConnection(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + // Only used in tests, so a busy-loop should be fine. + for c.conn == nil || c.connectedSince.IsZero() || !c.helloReceived { + if err := ctx.Err(); err != nil { + return err + } + + c.mu.Unlock() + time.Sleep(time.Nanosecond) + c.mu.Lock() + } + + return nil +} + +func (c *RemoteConnection) WaitForDisconnect(ctx context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + + initial := c.conn + if initial == nil { + return nil + } + + // Only used in tests, so a busy-loop should be fine. 
+ for c.conn == initial { + if err := ctx.Err(); err != nil { + return err + } + + c.mu.Unlock() + time.Sleep(time.Nanosecond) + c.mu.Lock() + } + return nil +} + +func Test_ProxyRemoteConnectionReconnect(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + server, key, httpserver := newProxyServerForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + conn, err := NewRemoteConnection(server, httpserver.URL, TokenIdForTest, key, nil) + require.NoError(err) + t.Cleanup(func() { + assert.NoError(conn.SendBye()) + assert.NoError(conn.Close()) + }) + + assert.NoError(conn.WaitForConnection(ctx)) + + // Closing the connection will reconnect automatically + conn.mu.Lock() + c := conn.conn + conn.mu.Unlock() + assert.NoError(c.Close()) + assert.NoError(conn.WaitForDisconnect(ctx)) + assert.NoError(conn.WaitForConnection(ctx)) +} + +func Test_ProxyRemoteConnectionReconnectUnknownSession(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + server, key, httpserver := newProxyServerForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + conn, err := NewRemoteConnection(server, httpserver.URL, TokenIdForTest, key, nil) + require.NoError(err) + t.Cleanup(func() { + assert.NoError(conn.SendBye()) + assert.NoError(conn.Close()) + }) + + assert.NoError(conn.WaitForConnection(ctx)) + + // Closing the connection will reconnect automatically + conn.mu.Lock() + c := conn.conn + sessionId := conn.sessionId + conn.mu.Unlock() + var sid uint64 + server.IterateSessions(func(session *ProxySession) { + if session.PublicId() == sessionId { + sid = session.Sid() + } + }) + require.NotEqualValues(0, sid) + server.DeleteSession(sid) + if err := c.Close(); err != nil { + // If an error occurs while closing, it may only be "use of closed network + // connection" because the "DeleteSession" might have already closed the + // socket. 
+ assert.ErrorIs(err, net.ErrClosed) + } + assert.NoError(conn.WaitForDisconnect(ctx)) + assert.NoError(conn.WaitForConnection(ctx)) + assert.NotEqual(sessionId, conn.SessionId()) +} + +func Test_ProxyRemoteConnectionReconnectExpiredSession(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + server, key, httpserver := newProxyServerForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + conn, err := NewRemoteConnection(server, httpserver.URL, TokenIdForTest, key, nil) + require.NoError(err) + t.Cleanup(func() { + assert.NoError(conn.SendBye()) + assert.NoError(conn.Close()) + }) + + assert.NoError(conn.WaitForConnection(ctx)) + + // Closing the connection will reconnect automatically + conn.mu.Lock() + sessionId := conn.sessionId + conn.mu.Unlock() + var session *ProxySession + server.IterateSessions(func(sess *ProxySession) { + if sess.PublicId() == sessionId { + session = sess + } + }) + require.NotNil(session) + session.Close() + assert.NoError(conn.WaitForDisconnect(ctx)) + assert.NoError(conn.WaitForConnection(ctx)) + assert.NotEqual(sessionId, conn.SessionId()) +} + +func Test_ProxyRemoteConnectionCreatePublisher(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + server, key, httpserver := newProxyServerForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + conn, err := NewRemoteConnection(server, httpserver.URL, TokenIdForTest, key, nil) + require.NoError(err) + t.Cleanup(func() { + assert.NoError(conn.SendBye()) + assert.NoError(conn.Close()) + }) + + publisherId := "the-publisher" + hostname := "the-hostname" + port := 1234 + rtcpPort := 2345 + + _, err = conn.RequestMessage(ctx, &signaling.ProxyClientMessage{ + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "publish-remote", + ClientId: publisherId, + Hostname: hostname, + Port: port, + RtcpPort: rtcpPort, + }, + }) + 
assert.ErrorContains(err, UnknownClient.Error()) +} diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 89bbd9a..6190275 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -133,20 +133,25 @@ type ProxyServer struct { sid atomic.Uint64 cookie *signaling.SessionIdCodec - sessions map[uint64]*ProxySession sessionsLock sync.RWMutex + // +checklocks:sessionsLock + sessions map[uint64]*ProxySession - clients map[string]signaling.McuClient - clientIds map[string]string clientsLock sync.RWMutex + // +checklocks:clientsLock + clients map[string]signaling.McuClient + // +checklocks:clientsLock + clientIds map[string]string tokenId string tokenKey *rsa.PrivateKey - remoteTlsConfig *tls.Config + remoteTlsConfig *tls.Config // +checklocksignore: Only written to from constructor. remoteHostname string - remoteConnections map[string]*RemoteConnection remoteConnectionsLock sync.Mutex - remotePublishers map[string]map[*proxyRemotePublisher]bool + // +checklocks:remoteConnectionsLock + remoteConnections map[string]*RemoteConnection + // +checklocks:remoteConnectionsLock + remotePublishers map[string]map[*proxyRemotePublisher]bool } func IsPublicIP(IP net.IP) bool { @@ -925,6 +930,13 @@ func (s *ProxyServer) addRemotePublisher(publisher *proxyRemotePublisher) { log.Printf("Add remote publisher to %s", publisher.remoteUrl) } +func (s *ProxyServer) hasRemotePublishers() bool { + s.remoteConnectionsLock.Lock() + defer s.remoteConnectionsLock.Unlock() + + return len(s.remotePublishers) > 0 +} + func (s *ProxyServer) removeRemotePublisher(publisher *proxyRemotePublisher) { s.remoteConnectionsLock.Lock() defer s.remoteConnectionsLock.Unlock() @@ -1510,6 +1522,7 @@ func (s *ProxyServer) DeleteSession(id uint64) { s.deleteSessionLocked(id) } +// +checklocks:s.sessionsLock func (s *ProxyServer) deleteSessionLocked(id uint64) { if session, found := s.sessions[id]; found { delete(s.sessions, id) diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go 
index d4e0be5..25b4311 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -89,7 +89,7 @@ func WaitForProxyServer(ctx context.Context, t *testing.T, proxy *ProxyServer) { case <-ctx.Done(): proxy.clientsLock.Lock() proxy.remoteConnectionsLock.Lock() - assert.Fail(t, "Error waiting for proxy to terminate", "clients %+v / sessions %+v / remoteConnections %+v: %+v", proxy.clients, proxy.sessions, proxy.remoteConnections, ctx.Err()) + assert.Fail(t, "Error waiting for proxy to terminate", "clients %+v / sessions %+v / remoteConnections %+v: %+v", clients, sessions, remoteConnections, ctx.Err()) proxy.remoteConnectionsLock.Unlock() proxy.clientsLock.Unlock() return @@ -626,6 +626,121 @@ func TestProxyCodecs(t *testing.T) { } } +type StreamTestMCU struct { + TestMCU + + streams []signaling.PublisherStream +} + +type StreamsTestPublisher struct { + TestMCUPublisher + + streams []signaling.PublisherStream +} + +func (m *StreamTestMCU) NewPublisher(ctx context.Context, listener signaling.McuListener, id signaling.PublicSessionId, sid string, streamType signaling.StreamType, settings signaling.NewPublisherSettings, initiator signaling.McuInitiator) (signaling.McuPublisher, error) { + return &StreamsTestPublisher{ + TestMCUPublisher: TestMCUPublisher{ + id: id, + sid: sid, + streamType: streamType, + }, + + streams: m.streams, + }, nil +} + +func (p *StreamsTestPublisher) GetStreams(ctx context.Context) ([]signaling.PublisherStream, error) { + return p.streams, nil +} + +func NewStreamTestMCU(t *testing.T, streams []signaling.PublisherStream) *StreamTestMCU { + return &StreamTestMCU{ + TestMCU: TestMCU{ + t: t, + }, + + streams: streams, + } +} + +func TestProxyStreams(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + streams := []signaling.PublisherStream{ + { + Mid: "0", + Mindex: 0, + Type: "audio", + Codec: "opus", + }, + { + Mid: "1", + Mindex: 
1, + Type: "video", + Codec: "vp8", + }, + } + + mcu := NewStreamTestMCU(t, streams) + proxy.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-publisher", + StreamType: signaling.StreamTypeVideo, + }, + })) + + var clientId string + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + if assert.NotEmpty(message.Command.Id) { + clientId = message.Command.Id + } + } + } + + require.NotEmpty(clientId, "should have received publisher id") + + require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{ + Id: "3456", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "get-publisher-streams", + ClientId: clientId, + }, + })) + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + assert.Equal(clientId, message.Command.Id) + assert.Equal(streams, message.Command.Streams) + } + } +} + type RemoteSubscriberTestMCU struct { TestMCU @@ -648,12 +763,18 @@ type TestRemotePublisher struct { refcnt atomic.Int32 closed context.Context closeFunc context.CancelFunc + listener signaling.McuListener + controller signaling.RemotePublisherController } func (p *TestRemotePublisher) Id() string { return "id" } +func (p *TestRemotePublisher) PublisherId() signaling.PublicSessionId { + 
return "id" +} + +func (p *TestRemotePublisher) Sid() string { return "sid" } @@ -669,6 +790,13 @@ func (p *TestRemotePublisher) MaxBitrate() int { func (p *TestRemotePublisher) Close(ctx context.Context) { if count := p.refcnt.Add(-1); assert.True(p.t, count >= 0) && count == 0 { p.closeFunc() + shortCtx, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + // Won't be able to perform remote call to actually stop publishing. + if err := p.controller.StopPublishing(shortCtx, p); !errors.Is(err, context.DeadlineExceeded) { + assert.NoError(p.t, err) + } + p.listener.PublisherClosed(p) } } @@ -684,6 +812,13 @@ func (p *TestRemotePublisher) RtcpPort() int { return 2 } +func (p *TestRemotePublisher) SetMedia(mediaType signaling.MediaType) { +} + +func (p *TestRemotePublisher) HasMedia(mediaType signaling.MediaType) bool { + return false +} + func (m *RemoteSubscriberTestMCU) NewRemotePublisher(ctx context.Context, listener signaling.McuListener, controller signaling.RemotePublisherController, streamType signaling.StreamType) (signaling.McuRemotePublisher, error) { require.Nil(m.t, m.publisher) assert.EqualValues(m.t, "video", streamType) @@ -694,6 +829,8 @@ func (m *RemoteSubscriberTestMCU) NewRemotePublisher(ctx context.Context, listen streamType: streamType, closed: closeCtx, closeFunc: closeFunc, + listener: listener, + controller: controller, } m.publisher.refcnt.Add(1) return m.publisher, nil @@ -813,6 +950,8 @@ func TestProxyRemoteSubscriber(t *testing.T) { } } + assert.True(proxy.hasRemotePublishers()) + require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{ Id: "3456", Type: "command", @@ -841,6 +980,8 @@ assert.Fail("publisher was not closed") } } + + assert.False(proxy.hasRemotePublishers()) } func TestProxyCloseRemoteOnSessionClose(t *testing.T) { @@ -936,10 +1077,12 @@ func NewUnpublishRemoteTestMCU(t *testing.T) *UnpublishRemoteTestMCU { type UnpublishRemoteTestPublisher struct {
TestMCUPublisher - t *testing.T + t *testing.T // +checklocksignore: Only written to from constructor. - mu sync.RWMutex - remoteId signaling.PublicSessionId + mu sync.RWMutex + // +checklocks:mu + remoteId signaling.PublicSessionId + // +checklocks:mu remoteData *remotePublisherData } diff --git a/proxy/proxy_session.go b/proxy/proxy_session.go index 308708d..c97acc5 100644 --- a/proxy/proxy_session.go +++ b/proxy/proxy_session.go @@ -53,20 +53,27 @@ type ProxySession struct { ctx context.Context closeFunc context.CancelFunc - clientLock sync.Mutex - client *ProxyClient + clientLock sync.Mutex + // +checklocks:clientLock + client *ProxyClient + // +checklocks:clientLock pendingMessages []*signaling.ProxyServerMessage publishersLock sync.Mutex - publishers map[string]signaling.McuPublisher - publisherIds map[signaling.McuPublisher]string + // +checklocks:publishersLock + publishers map[string]signaling.McuPublisher + // +checklocks:publishersLock + publisherIds map[signaling.McuPublisher]string subscribersLock sync.Mutex - subscribers map[string]signaling.McuSubscriber - subscriberIds map[signaling.McuSubscriber]string + // +checklocks:subscribersLock + subscribers map[string]signaling.McuSubscriber + // +checklocks:subscribersLock + subscriberIds map[signaling.McuSubscriber]string remotePublishersLock sync.Mutex - remotePublishers map[signaling.McuRemoteAwarePublisher]map[string]*remotePublisherData + // +checklocks:remotePublishersLock + remotePublishers map[signaling.McuRemoteAwarePublisher]map[string]*remotePublisherData } func NewProxySession(proxy *ProxyServer, sid uint64, id signaling.PublicSessionId) *ProxySession { @@ -301,6 +308,8 @@ func (s *ProxySession) DeletePublisher(publisher signaling.McuPublisher) string delete(s.publishers, id) delete(s.publisherIds, publisher) if rp, ok := publisher.(signaling.McuRemoteAwarePublisher); ok { + s.remotePublishersLock.Lock() + defer s.remotePublishersLock.Unlock() delete(s.remotePublishers, rp) } go 
s.proxy.PublisherDeleted(publisher) @@ -363,8 +372,8 @@ func (s *ProxySession) clearRemotePublishers() { } func (s *ProxySession) clearSubscribers() { - s.publishersLock.Lock() - defer s.publishersLock.Unlock() + s.subscribersLock.Lock() + defer s.subscribersLock.Unlock() go func(subscribers map[string]signaling.McuSubscriber) { for id, subscriber := range subscribers { diff --git a/proxy/proxy_testclient_test.go b/proxy/proxy_testclient_test.go index a239f41..6b983b2 100644 --- a/proxy/proxy_testclient_test.go +++ b/proxy/proxy_testclient_test.go @@ -43,10 +43,11 @@ var ( type ProxyTestClient struct { t *testing.T - assert *assert.Assertions + assert *assert.Assertions // +checklocksignore: Only written to from constructor. require *require.Assertions - mu sync.Mutex + mu sync.Mutex + // +checklocks:mu conn *websocket.Conn messageChan chan []byte readErrorChan chan error diff --git a/proxy_config_etcd.go b/proxy_config_etcd.go index 68cc996..faf1582 100644 --- a/proxy_config_etcd.go +++ b/proxy_config_etcd.go @@ -35,12 +35,14 @@ import ( type proxyConfigEtcd struct { mu sync.Mutex - proxy McuProxy + proxy McuProxy // +checklocksignore: Only written to from constructor. 
client *EtcdClient keyPrefix string - keyInfos map[string]*ProxyInformationEtcd - urlToKey map[string]string + // +checklocks:mu + keyInfos map[string]*ProxyInformationEtcd + // +checklocks:mu + urlToKey map[string]string closeCtx context.Context closeFunc context.CancelFunc @@ -211,6 +213,7 @@ func (p *proxyConfigEtcd) EtcdKeyDeleted(client *EtcdClient, key string, prevVal p.removeEtcdProxyLocked(key) } +// +checklocks:p.mu func (p *proxyConfigEtcd) removeEtcdProxyLocked(key string) { info, found := p.keyInfos[key] if !found { diff --git a/proxy_config_static.go b/proxy_config_static.go index b3b2573..dcdaa59 100644 --- a/proxy_config_static.go +++ b/proxy_config_static.go @@ -43,9 +43,11 @@ type proxyConfigStatic struct { mu sync.Mutex proxy McuProxy - dnsMonitor *DnsMonitor + dnsMonitor *DnsMonitor + // +checklocks:mu dnsDiscovery bool + // +checklocks:mu connectionsMap map[string]*ipList } @@ -107,7 +109,7 @@ func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool } if dnsDiscovery { - p.connectionsMap[u] = &ipList{ + p.connectionsMap[u] = &ipList{ // +checklocksignore: Not supported for iter loops yet, see https://github.com/google/gvisor/issues/12176 hostname: parsed.Host, } continue @@ -124,7 +126,7 @@ func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool } } - p.connectionsMap[u] = &ipList{ + p.connectionsMap[u] = &ipList{ // +checklocksignore: Not supported for iter loops yet, see https://github.com/google/gvisor/issues/12176 hostname: parsed.Host, } } diff --git a/proxy_config_test.go b/proxy_config_test.go index 2a65b7c..817d268 100644 --- a/proxy_config_test.go +++ b/proxy_config_test.go @@ -52,10 +52,12 @@ type proxyConfigEvent struct { } type mcuProxyForConfig struct { - t *testing.T + t *testing.T + mu sync.Mutex + // +checklocks:mu expected []proxyConfigEvent - mu sync.Mutex - waiters []chan struct{} + // +checklocks:mu + waiters []chan struct{} } func newMcuProxyForConfig(t *testing.T) 
*mcuProxyForConfig { @@ -63,6 +65,8 @@ func newMcuProxyForConfig(t *testing.T) *mcuProxyForConfig { t: t, } t.Cleanup(func() { + proxy.mu.Lock() + defer proxy.mu.Unlock() assert.Empty(t, proxy.expected) }) return proxy @@ -83,20 +87,29 @@ func (p *mcuProxyForConfig) Expect(action string, url string, ips ...net.IP) { }) } -func (p *mcuProxyForConfig) WaitForEvents(ctx context.Context) { +func (p *mcuProxyForConfig) addWaiter() chan struct{} { p.t.Helper() p.mu.Lock() defer p.mu.Unlock() if len(p.expected) == 0 { - return + return nil } waiter := make(chan struct{}) p.waiters = append(p.waiters, waiter) - p.mu.Unlock() - defer p.mu.Lock() + return waiter +} + +func (p *mcuProxyForConfig) WaitForEvents(ctx context.Context) { + p.t.Helper() + + waiter := p.addWaiter() + if waiter == nil { + return + } + select { case <-ctx.Done(): assert.NoError(p.t, ctx.Err()) @@ -104,6 +117,32 @@ func (p *mcuProxyForConfig) WaitForEvents(ctx context.Context) { } } +func (p *mcuProxyForConfig) getWaitersIfEmpty() []chan struct{} { + p.mu.Lock() + defer p.mu.Unlock() + + if len(p.expected) != 0 { + return nil + } + + waiters := p.waiters + p.waiters = nil + return waiters +} + +func (p *mcuProxyForConfig) getExpectedEvent() *proxyConfigEvent { + p.mu.Lock() + defer p.mu.Unlock() + + if len(p.expected) == 0 { + return nil + } + + expected := p.expected[0] + p.expected = p.expected[1:] + return &expected +} + func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) { p.t.Helper() pc := make([]uintptr, 32) @@ -121,32 +160,24 @@ func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) { } } - p.mu.Lock() - defer p.mu.Unlock() - - if len(p.expected) == 0 { + expected := p.getExpectedEvent() + if expected == nil { assert.Fail(p.t, "no event expected", "received %+v from %s:%d", event, caller.File, caller.Line) return } - defer func() { - if len(p.expected) == 0 { - waiters := p.waiters - p.waiters = nil - p.mu.Unlock() - defer p.mu.Lock() - - for _, ch := range waiters { - ch 
<- struct{}{} - } - } - }() - - expected := p.expected[0] - p.expected = p.expected[1:] - if !reflect.DeepEqual(expected, *event) { + if !reflect.DeepEqual(expected, event) { assert.Fail(p.t, "wrong event", "expected %+v, received %+v from %s:%d", expected, event, caller.File, caller.Line) } + + waiters := p.getWaitersIfEmpty() + if len(waiters) == 0 { + return + } + + for _, ch := range waiters { + ch <- struct{}{} + } } func (p *mcuProxyForConfig) AddConnection(ignoreErrors bool, url string, ips ...net.IP) error { diff --git a/publisher_stats_counter.go b/publisher_stats_counter.go index ba8b293..10257bb 100644 --- a/publisher_stats_counter.go +++ b/publisher_stats_counter.go @@ -28,7 +28,9 @@ import ( type publisherStatsCounter struct { mu sync.Mutex + // +checklocks:mu streamTypes map[StreamType]bool + // +checklocks:mu subscribers map[string]bool } diff --git a/room.go b/room.go index 9f7a98f..372cf5b 100644 --- a/room.go +++ b/room.go @@ -69,17 +69,23 @@ type Room struct { events AsyncEvents backend *Backend + // +checklocks:mu properties json.RawMessage - closer *Closer - mu *sync.RWMutex + closer *Closer + mu *sync.RWMutex + // +checklocks:mu sessions map[PublicSessionId]Session - + // +checklocks:mu internalSessions map[*ClientSession]bool - virtualSessions map[*VirtualSession]bool - inCallSessions map[Session]bool - roomSessionData map[PublicSessionId]*RoomSessionData + // +checklocks:mu + virtualSessions map[*VirtualSession]bool + // +checklocks:mu + inCallSessions map[Session]bool + // +checklocks:mu + roomSessionData map[PublicSessionId]*RoomSessionData + // +checklocks:mu statsRoomSessionsCurrent *prometheus.GaugeVec // Users currently in the room @@ -590,6 +596,7 @@ func (r *Room) PublishSessionLeft(session Session) { } } +// +checklocksread:r.mu func (r *Room) getClusteredInternalSessionsRLocked() (internal map[PublicSessionId]*InternalSessionData, virtual map[PublicSessionId]*VirtualSessionData) { if r.hub.rpcClients == nil { return nil, nil diff 
--git a/room_ping.go b/room_ping.go index 8370c7c..3e9d708 100644 --- a/room_ping.go +++ b/room_ping.go @@ -70,6 +70,7 @@ type RoomPing struct { backend *BackendClient capabilities *Capabilities + // +checklocks:mu entries map[string]*pingEntries } diff --git a/roomsessions_builtin.go b/roomsessions_builtin.go index c39926f..02a53c4 100644 --- a/roomsessions_builtin.go +++ b/roomsessions_builtin.go @@ -30,9 +30,11 @@ import ( ) type BuiltinRoomSessions struct { + mu sync.RWMutex + // +checklocks:mu sessionIdToRoomSession map[PublicSessionId]RoomSessionId + // +checklocks:mu roomSessionToSessionid map[RoomSessionId]PublicSessionId - mu sync.RWMutex clients *GrpcClients } diff --git a/server/main.go b/server/main.go index 80474a8..879accb 100644 --- a/server/main.go +++ b/server/main.go @@ -93,7 +93,8 @@ func createTLSListener(addr string, certFile, keyFile string) (net.Listener, err } type Listeners struct { - mu sync.Mutex + mu sync.Mutex + // +checklocks:mu listeners []net.Listener } diff --git a/session.go b/session.go index a27efea..0a188e8 100644 --- a/session.go +++ b/session.go @@ -44,7 +44,7 @@ var ( // DefaultPermissionOverrides contains permission overrides for users where // no permissions have been set by the server. If a permission is not set in // this map, it's assumed the user has that permission. - DefaultPermissionOverrides = map[Permission]bool{ + DefaultPermissionOverrides = map[Permission]bool{ // +checklocksignore: Global readonly variable. 
PERMISSION_HIDE_DISPLAYNAMES: false, } ) diff --git a/single_notifier.go b/single_notifier.go index 921542a..2a41f7d 100644 --- a/single_notifier.go +++ b/single_notifier.go @@ -67,7 +67,9 @@ func (w *SingleWaiter) cancel() { type SingleNotifier struct { sync.Mutex - waiter *SingleWaiter + // +checklocks:Mutex + waiter *SingleWaiter + // +checklocks:Mutex waiters map[*SingleWaiter]bool } diff --git a/testclient_test.go b/testclient_test.go index ec00fde..8ce9081 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -202,7 +202,8 @@ type TestClient struct { hub *Hub server *httptest.Server - mu sync.Mutex + mu sync.Mutex + // +checklocks:mu conn *websocket.Conn localAddr net.Addr diff --git a/throttle.go b/throttle.go index 8c96eac..6e78582 100644 --- a/throttle.go +++ b/throttle.go @@ -94,7 +94,8 @@ type memoryThrottler struct { getNow func() time.Time doDelay func(context.Context, time.Duration) - mu sync.RWMutex + mu sync.RWMutex + // +checklocks:mu clients map[string]map[string][]throttleEntry closer *Closer diff --git a/transient_data.go b/transient_data.go index 7e6f7d8..f3fbf9f 100644 --- a/transient_data.go +++ b/transient_data.go @@ -35,11 +35,15 @@ type TransientListener interface { } type TransientData struct { - mu sync.Mutex - data api.StringMap + mu sync.Mutex + // +checklocks:mu + data api.StringMap + // +checklocks:mu listeners map[TransientListener]bool - timers map[string]*time.Timer - ttlCh chan<- struct{} + // +checklocks:mu + timers map[string]*time.Timer + // +checklocks:mu + ttlCh chan<- struct{} } // NewTransientData creates a new transient data container. 
@@ -47,6 +51,7 @@ func NewTransientData() *TransientData { return &TransientData{} } +// +checklocks:t.mu func (t *TransientData) sendMessageToListener(listener TransientListener, message *ServerMessage) { t.mu.Unlock() defer t.mu.Lock() @@ -54,6 +59,7 @@ func (t *TransientData) sendMessageToListener(listener TransientListener, messag listener.SendMessage(message) } +// +checklocks:t.mu func (t *TransientData) notifySet(key string, prev, value any) { msg := &ServerMessage{ Type: "transient", @@ -69,6 +75,7 @@ func (t *TransientData) notifySet(key string, prev, value any) { } } +// +checklocks:t.mu func (t *TransientData) notifyDeleted(key string, prev any) { msg := &ServerMessage{ Type: "transient", @@ -112,6 +119,7 @@ func (t *TransientData) RemoveListener(listener TransientListener) { delete(t.listeners, listener) } +// +checklocks:t.mu func (t *TransientData) updateTTL(key string, value any, ttl time.Duration) { if ttl <= 0 { delete(t.timers, key) @@ -120,6 +128,7 @@ func (t *TransientData) updateTTL(key string, value any, ttl time.Duration) { } } +// +checklocks:t.mu func (t *TransientData) removeAfterTTL(key string, value any, ttl time.Duration) { if ttl <= 0 { return @@ -147,6 +156,7 @@ func (t *TransientData) removeAfterTTL(key string, value any, ttl time.Duration) t.timers[key] = timer } +// +checklocks:t.mu func (t *TransientData) doSet(key string, value any, prev any, ttl time.Duration) { if t.data == nil { t.data = make(api.StringMap) @@ -210,6 +220,7 @@ func (t *TransientData) CompareAndSetTTL(key string, old, value any, ttl time.Du return true } +// +checklocks:t.mu func (t *TransientData) doRemove(key string, prev any) { delete(t.data, key) if old, found := t.timers[key]; found { @@ -243,6 +254,7 @@ func (t *TransientData) CompareAndRemove(key string, old any) bool { return t.compareAndRemove(key, old) } +// +checklocks:t.mu func (t *TransientData) compareAndRemove(key string, old any) bool { prev, found := t.data[key] if !found || 
!reflect.DeepEqual(prev, old) { diff --git a/transient_data_test.go b/transient_data_test.go index feea5fb..ca554fc 100644 --- a/transient_data_test.go +++ b/transient_data_test.go @@ -92,6 +92,7 @@ type MockTransientListener struct { sending chan struct{} done chan struct{} + // +checklocks:mu data *TransientData } diff --git a/virtualsession_test.go b/virtualsession_test.go index 2f741df..3e7ddf1 100644 --- a/virtualsession_test.go +++ b/virtualsession_test.go @@ -45,7 +45,7 @@ func TestVirtualSession(t *testing.T) { backend := &Backend{ id: "compat", } - room, err := hub.createRoom(roomId, emptyProperties, backend) + room, err := hub.CreateRoom(roomId, emptyProperties, backend) require.NoError(err) defer room.Close() @@ -229,7 +229,7 @@ func TestVirtualSessionActorInformation(t *testing.T) { backend := &Backend{ id: "compat", } - room, err := hub.createRoom(roomId, emptyProperties, backend) + room, err := hub.CreateRoom(roomId, emptyProperties, backend) require.NoError(err) defer room.Close() @@ -439,7 +439,7 @@ func TestVirtualSessionCustomInCall(t *testing.T) { backend := &Backend{ id: "compat", } - room, err := hub.createRoom(roomId, emptyProperties, backend) + room, err := hub.CreateRoom(roomId, emptyProperties, backend) require.NoError(err) defer room.Close() @@ -581,7 +581,7 @@ func TestVirtualSessionCleanup(t *testing.T) { backend := &Backend{ id: "compat", } - room, err := hub.createRoom(roomId, emptyProperties, backend) + room, err := hub.CreateRoom(roomId, emptyProperties, backend) require.NoError(err) defer room.Close()