Use gvisor checklocks for static lock analysis.

Also fix locking issues found along the way.

See https://github.com/google/gvisor/tree/master/tools/checklocks for details.
This commit is contained in:
Joachim Bauch 2025-09-26 17:20:17 +02:00
commit cfd8cf6718
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
61 changed files with 707 additions and 345 deletions

View file

@ -8,6 +8,7 @@ on:
- '.golangci.yml'
- '**.go'
- 'go.*'
- 'Makefile'
pull_request:
branches: [ master ]
paths:
@ -15,6 +16,7 @@ on:
- '.golangci.yml'
- '**.go'
- 'go.*'
- 'Makefile'
permissions:
contents: read
@ -53,6 +55,20 @@ jobs:
run: |
GOEXPERIMENT=synctest go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -test ./...
checklocks:
name: checklocks
runs-on: ubuntu-latest
continue-on-error: true
steps:
- uses: actions/checkout@v5
- uses: actions/setup-go@v6
with:
go-version: "1.24"
- name: checklocks
run: |
make checklocks
dependencies:
name: dependencies
runs-on: ubuntu-latest

View file

@ -85,6 +85,9 @@ $(GOPATHBIN)/protoc-gen-go: go.mod go.sum
$(GOPATHBIN)/protoc-gen-go-grpc: go.mod go.sum
$(GO) install google.golang.org/grpc/cmd/protoc-gen-go-grpc
$(GOPATHBIN)/checklocks: go.mod go.sum
$(GO) install gvisor.dev/gvisor/tools/checklocks/cmd/checklocks@go
continentmap.go:
$(CURDIR)/scripts/get_continent_map.py $@
@ -108,6 +111,9 @@ vet:
test: vet
GOEXPERIMENT=synctest $(GO) test -timeout $(TIMEOUT) $(TESTARGS) ./...
checklocks: $(GOPATHBIN)/checklocks
GOEXPERIMENT=synctest go vet -vettool=$(GOPATHBIN)/checklocks ./...
cover: vet
rm -f cover.out && \
GOEXPERIMENT=synctest $(GO) test -timeout $(TIMEOUT) -coverprofile cover.out ./...

View file

@ -52,7 +52,7 @@ const (
)
var (
ErrNoSdp = NewError("no_sdp", "Payload does not contain a SDP.")
ErrNoSdp = NewError("no_sdp", "Payload does not contain a SDP.") // +checklocksignore: Global readonly variable.
ErrInvalidSdp = NewError("invalid_sdp", "Payload does not contain a valid SDP.")
ErrNoCandidate = NewError("no_candidate", "Payload does not contain a candidate.")

View file

@ -72,6 +72,7 @@ func NewAsyncEvents(url string) (AsyncEvents, error) {
type asyncBackendRoomSubscriber struct {
mu sync.Mutex
// +checklocks:mu
listeners map[AsyncBackendRoomEventListener]bool
}
@ -107,6 +108,7 @@ func (s *asyncBackendRoomSubscriber) removeListener(listener AsyncBackendRoomEve
type asyncRoomSubscriber struct {
mu sync.Mutex
// +checklocks:mu
listeners map[AsyncRoomEventListener]bool
}
@ -142,6 +144,7 @@ func (s *asyncRoomSubscriber) removeListener(listener AsyncRoomEventListener) bo
type asyncUserSubscriber struct {
mu sync.Mutex
// +checklocks:mu
listeners map[AsyncUserEventListener]bool
}
@ -177,6 +180,7 @@ func (s *asyncUserSubscriber) removeListener(listener AsyncUserEventListener) bo
type asyncSessionSubscriber struct {
mu sync.Mutex
// +checklocks:mu
listeners map[AsyncSessionEventListener]bool
}

View file

@ -231,10 +231,14 @@ type asyncEventsNats struct {
mu sync.Mutex
client NatsClient
// +checklocks:mu
backendRoomSubscriptions map[string]*asyncBackendRoomSubscriberNats
roomSubscriptions map[string]*asyncRoomSubscriberNats
userSubscriptions map[string]*asyncUserSubscriberNats
sessionSubscriptions map[string]*asyncSessionSubscriberNats
// +checklocks:mu
roomSubscriptions map[string]*asyncRoomSubscriberNats
// +checklocks:mu
userSubscriptions map[string]*asyncUserSubscriberNats
// +checklocks:mu
sessionSubscriptions map[string]*asyncSessionSubscriberNats
}
func NewAsyncEventsNats(client NatsClient) (AsyncEvents, error) {

View file

@ -56,7 +56,8 @@ type Backend struct {
sessionLimit uint64
sessionsLock sync.Mutex
sessions map[PublicSessionId]bool
// +checklocks:sessionsLock
sessions map[PublicSessionId]bool
counted bool
}
@ -170,7 +171,8 @@ type BackendStorage interface {
}
type backendStorageCommon struct {
mu sync.RWMutex
mu sync.RWMutex
// +checklocks:mu
backends map[string][]*Backend
}

View file

@ -32,7 +32,7 @@ var (
Name: "session_limit",
Help: "The session limit of a backend",
}, []string{"backend"})
statsBackendLimitExceededTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
statsBackendLimitExceededTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ // +checklocksignore: Global readonly variable.
Namespace: "signaling",
Subsystem: "backend",
Name: "session_limit_exceeded_total",

View file

@ -573,7 +573,9 @@ func TestBackendConfiguration_EtcdCompat(t *testing.T) {
assert.Equal(secret3, string(backends[0].secret))
}
storage.mu.RLock()
_, found := storage.backends["domain1.invalid"]
storage.mu.RUnlock()
assert.False(found, "Should have removed host information")
}
@ -806,6 +808,8 @@ func TestBackendConfiguration_EtcdChangeUrls(t *testing.T) {
<-ch
checkStatsValue(t, statsBackendsCurrent, 0)
storage.mu.RLock()
_, found := storage.backends["domain1.invalid"]
storage.mu.RUnlock()
assert.False(found, "Should have removed host information")
}

View file

@ -628,7 +628,7 @@ func RunTestBackendServer_RoomUpdate(t *testing.T) {
emptyProperties := json.RawMessage("{}")
backend := hub.backend.GetBackend(u)
require.NotNil(backend, "Did not find backend")
room, err := hub.createRoom(roomId, emptyProperties, backend)
room, err := hub.CreateRoom(roomId, emptyProperties, backend)
require.NoError(err, "Could not create room")
defer room.Close()
@ -696,7 +696,7 @@ func RunTestBackendServer_RoomDelete(t *testing.T) {
emptyProperties := json.RawMessage("{}")
backend := hub.backend.GetBackend(u)
require.NotNil(backend, "Did not find backend")
_, err = hub.createRoom(roomId, emptyProperties, backend)
_, err = hub.CreateRoom(roomId, emptyProperties, backend)
require.NoError(err)
userid := "test-userid"

View file

@ -151,6 +151,7 @@ func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error)
// Close releases resources held by the static backend storage. The static
// configuration holds no background goroutines or external handles, so this
// is intentionally a no-op; it exists to satisfy the BackendStorage interface.
func (s *backendStorageStatic) Close() {
}
// +checklocks:s.mu
func (s *backendStorageStatic) RemoveBackendsForHost(host string, seen map[string]seenState) {
if oldBackends := s.backends[host]; len(oldBackends) > 0 {
deleted := 0
@ -185,6 +186,7 @@ const (
seenDeleted
)
// +checklocks:s.mu
func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend, seen map[string]seenState) {
for existingIndex, existingBackend := range s.backends[host] {
found := false

View file

@ -57,16 +57,22 @@ const (
)
var (
ErrUnexpectedHttpStatus = errors.New("unexpected_http_status")
ErrUnexpectedHttpStatus = errors.New("unexpected_http_status") // +checklocksignore: Global readonly variable.
)
type capabilitiesEntry struct {
c *Capabilities
mu sync.RWMutex
nextUpdate time.Time
etag string
mu sync.RWMutex
// +checklocks:mu
c *Capabilities
// +checklocks:mu
nextUpdate time.Time
// +checklocks:mu
etag string
// +checklocks:mu
mustRevalidate bool
capabilities api.StringMap
// +checklocks:mu
capabilities api.StringMap
}
func newCapabilitiesEntry(c *Capabilities) *capabilitiesEntry {
@ -82,10 +88,12 @@ func (e *capabilitiesEntry) valid(now time.Time) bool {
return e.validLocked(now)
}
// validLocked reports whether the cached capabilities entry is still fresh,
// i.e. its scheduled next update time lies after "now". The caller must hold
// at least a read lock on e.mu (enforced by the checklocks annotation below).
// +checklocksread:e.mu
func (e *capabilitiesEntry) validLocked(now time.Time) bool {
return e.nextUpdate.After(now)
}
// +checklocks:e.mu
func (e *capabilitiesEntry) updateRequest(r *http.Request) {
if e.etag != "" {
r.Header.Set("If-None-Match", e.etag)
@ -99,6 +107,7 @@ func (e *capabilitiesEntry) invalidate() {
e.nextUpdate = time.Now()
}
// +checklocks:e.mu
func (e *capabilitiesEntry) errorIfMustRevalidate(err error) (bool, error) {
if !e.mustRevalidate {
return false, nil
@ -238,9 +247,11 @@ type Capabilities struct {
// Can be overwritten by tests.
getNow func() time.Time
version string
pool *HttpClientPool
entries map[string]*capabilitiesEntry
version string
pool *HttpClientPool
// +checklocks:mu
entries map[string]*capabilitiesEntry
// +checklocks:mu
nextInvalidate map[string]time.Time
buffers BufferPool

View file

@ -26,8 +26,10 @@ import (
)
type ChannelWaiters struct {
mu sync.RWMutex
id uint64
mu sync.RWMutex
// +checklocks:mu
id uint64
// +checklocks:mu
waiters map[uint64]chan struct{}
}

View file

@ -130,7 +130,8 @@ type Client struct {
logRTT bool
handlerMu sync.RWMutex
handler ClientHandler
// +checklocks:handlerMu
handler ClientHandler
session atomic.Pointer[Session]
sessionId atomic.Pointer[PublicSessionId]

View file

@ -113,7 +113,7 @@ type MessagePayload struct {
}
type SignalingClient struct {
readyWg *sync.WaitGroup
readyWg *sync.WaitGroup // +checklocksignore: Only written to from constructor.
cookie *signaling.SessionIdCodec
conn *websocket.Conn
@ -123,10 +123,13 @@ type SignalingClient struct {
stopChan chan struct{}
lock sync.Mutex
lock sync.Mutex
// +checklocks:lock
privateSessionId signaling.PrivateSessionId
publicSessionId signaling.PublicSessionId
userId string
// +checklocks:lock
publicSessionId signaling.PublicSessionId
// +checklocks:lock
userId string
}
func NewSignalingClient(cookie *signaling.SessionIdCodec, url string, stats *Stats, readyWg *sync.WaitGroup, doneWg *sync.WaitGroup) (*SignalingClient, error) {

View file

@ -70,9 +70,11 @@ type ClientSession struct {
parseUserData func() (api.StringMap, error)
inCall Flags
inCall Flags
// +checklocks:mu
supportsPermissions bool
permissions map[Permission]bool
// +checklocks:mu
permissions map[Permission]bool
backend *Backend
backendUrl string
@ -80,30 +82,40 @@ type ClientSession struct {
mu sync.Mutex
// +checklocks:mu
client HandlerClient
room atomic.Pointer[Room]
roomJoinTime atomic.Int64
federation atomic.Pointer[FederationClient]
roomSessionIdLock sync.RWMutex
roomSessionId RoomSessionId
// +checklocks:roomSessionIdLock
roomSessionId RoomSessionId
publisherWaiters ChannelWaiters
publisherWaiters ChannelWaiters // +checklocksignore
publishers map[StreamType]McuPublisher
// +checklocks:mu
publishers map[StreamType]McuPublisher
// +checklocks:mu
subscribers map[StreamId]McuSubscriber
pendingClientMessages []*ServerMessage
hasPendingChat bool
// +checklocks:mu
pendingClientMessages []*ServerMessage
// +checklocks:mu
hasPendingChat bool
// +checklocks:mu
hasPendingParticipantsUpdate bool
// +checklocks:mu
virtualSessions map[*VirtualSession]bool
seenJoinedLock sync.Mutex
seenJoinedLock sync.Mutex
// +checklocks:seenJoinedLock
seenJoinedEvents map[PublicSessionId]bool
responseHandlersLock sync.Mutex
responseHandlers map[string]ResponseHandlerFunc
// +checklocks:responseHandlersLock
responseHandlers map[string]ResponseHandlerFunc
}
func NewClientSession(hub *Hub, privateId PrivateSessionId, publicId PublicSessionId, data *SessionIdData, backend *Backend, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) {
@ -197,6 +209,19 @@ func (s *ClientSession) HasPermission(permission Permission) bool {
return s.hasPermissionLocked(permission)
}
// GetPermissions returns the permissions currently granted to this session.
// Only entries stored with value "true" in the permissions map are included;
// the order of the returned slice is unspecified (map iteration order).
func (s *ClientSession) GetPermissions() []Permission {
s.mu.Lock()
defer s.mu.Unlock()
// Allocate with zero length and only reserve capacity: combining
// make([]Permission, n) with append would prepend n zero-value
// permissions before the real entries.
result := make([]Permission, 0, len(s.permissions))
for p, ok := range s.permissions {
if ok {
result = append(result, p)
}
}
return result
}
// HasAnyPermission checks if the session has one of the passed permissions.
func (s *ClientSession) HasAnyPermission(permission ...Permission) bool {
if len(permission) == 0 {
@ -209,16 +234,16 @@ func (s *ClientSession) HasAnyPermission(permission ...Permission) bool {
return s.hasAnyPermissionLocked(permission...)
}
// +checklocks:s.mu
func (s *ClientSession) hasAnyPermissionLocked(permission ...Permission) bool {
if len(permission) == 0 {
return false
}
return slices.ContainsFunc(permission, func(p Permission) bool {
return s.hasPermissionLocked(p)
})
return slices.ContainsFunc(permission, s.hasPermissionLocked)
}
// +checklocks:s.mu
func (s *ClientSession) hasPermissionLocked(permission Permission) bool {
if !s.supportsPermissions {
// Old-style session that doesn't receive permissions from Nextcloud.
@ -340,6 +365,7 @@ func (s *ClientSession) getRoomJoinTime() time.Time {
return time.Unix(0, t)
}
// +checklocks:s.mu
func (s *ClientSession) releaseMcuObjects() {
if len(s.publishers) > 0 {
go func(publishers map[StreamType]McuPublisher) {
@ -489,6 +515,7 @@ func (s *ClientSession) LeaveRoomWithMessage(notify bool, message *ClientMessage
return s.doLeaveRoom(notify)
}
// +checklocks:s.mu
func (s *ClientSession) doLeaveRoom(notify bool) *Room {
room := s.GetRoom()
if room == nil {
@ -543,6 +570,7 @@ func (s *ClientSession) ClearClient(client HandlerClient) {
s.clearClientLocked(client)
}
// +checklocks:s.mu
func (s *ClientSession) clearClientLocked(client HandlerClient) {
if s.client == nil {
return
@ -563,6 +591,7 @@ func (s *ClientSession) GetClient() HandlerClient {
return s.getClientUnlocked()
}
// +checklocks:s.mu
func (s *ClientSession) getClientUnlocked() HandlerClient {
return s.client
}
@ -589,6 +618,7 @@ func (s *ClientSession) SetClient(client HandlerClient) HandlerClient {
return prev
}
// +checklocks:s.mu
func (s *ClientSession) sendOffer(client McuClient, sender PublicSessionId, streamType StreamType, offer api.StringMap) {
offer_message := &AnswerOfferMessage{
To: s.PublicId(),
@ -617,6 +647,7 @@ func (s *ClientSession) sendOffer(client McuClient, sender PublicSessionId, stre
s.sendMessageUnlocked(response_message)
}
// +checklocks:s.mu
func (s *ClientSession) sendCandidate(client McuClient, sender PublicSessionId, streamType StreamType, candidate any) {
candidate_message := &AnswerOfferMessage{
To: s.PublicId(),
@ -647,6 +678,7 @@ func (s *ClientSession) sendCandidate(client McuClient, sender PublicSessionId,
s.sendMessageUnlocked(response_message)
}
// +checklocks:s.mu
func (s *ClientSession) sendMessageUnlocked(message *ServerMessage) bool {
if c := s.getClientUnlocked(); c != nil {
if c.SendMessage(message) {
@ -768,6 +800,7 @@ func (e *PermissionError) Error() string {
return fmt.Sprintf("permission \"%s\" not found", e.permission)
}
// +checklocks:s.mu
func (s *ClientSession) isSdpAllowedToSendLocked(sdp *sdp.SessionDescription) (MediaType, error) {
if sdp == nil {
// Should have already been checked when data was validated.
@ -832,6 +865,7 @@ func (s *ClientSession) CheckOfferType(streamType StreamType, data *MessageClien
return s.checkOfferTypeLocked(streamType, data)
}
// +checklocks:s.mu
func (s *ClientSession) checkOfferTypeLocked(streamType StreamType, data *MessageClientMessageData) (MediaType, error) {
if streamType == StreamTypeScreen {
if !s.hasPermissionLocked(PERMISSION_MAY_PUBLISH_SCREEN) {
@ -893,6 +927,8 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea
if err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.publishers == nil {
s.publishers = make(map[StreamType]McuPublisher)
}
@ -915,6 +951,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea
return publisher, nil
}
// getPublisherLocked returns the publisher for the given stream type, or nil
// if this session has no publisher for that type. The caller must hold s.mu
// (enforced by the checklocks annotation below).
// +checklocks:s.mu
func (s *ClientSession) getPublisherLocked(streamType StreamType) McuPublisher {
return s.publishers[streamType]
}
@ -1120,6 +1157,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
s.SendMessage(serverMessage)
}
// +checklocks:s.mu
func (s *ClientSession) storePendingMessage(message *ServerMessage) {
if message.IsChatRefresh() {
if s.hasPendingChat {

View file

@ -27,7 +27,8 @@ import (
type ConcurrentMap[K comparable, V any] struct {
mu sync.RWMutex
d map[K]V
// +checklocks:mu
d map[K]V // +checklocksignore: Not supported yet, see https://github.com/google/gvisor/issues/11671
}
func (m *ConcurrentMap[K, V]) Set(key K, value V) {

View file

@ -57,11 +57,28 @@ type dnsMonitorEntry struct {
hostname string
hostIP net.IP
mu sync.Mutex
ips []net.IP
mu sync.Mutex
// +checklocks:mu
ips []net.IP
// +checklocks:mu
entries map[*DnsMonitorEntry]bool
}
// clearRemoved drops all registered monitor entries whose backing pointer has
// already been cleared. It returns true only if at least one entry was removed
// and no entries remain, i.e. this hostname entry itself can be discarded.
func (e *dnsMonitorEntry) clearRemoved() bool {
e.mu.Lock()
defer e.mu.Unlock()
removedAny := false
for candidate := range e.entries {
if candidate.entry.Load() != nil {
continue
}
delete(e.entries, candidate)
removedAny = true
}
return removedAny && len(e.entries) == 0
}
func (e *dnsMonitorEntry) setIPs(ips []net.IP, fromIP bool) {
e.mu.Lock()
defer e.mu.Unlock()
@ -134,6 +151,7 @@ func (e *dnsMonitorEntry) removeEntry(entry *DnsMonitorEntry) bool {
return len(e.entries) == 0
}
// +checklocks:e.mu
func (e *dnsMonitorEntry) runCallbacks(all []net.IP, add []net.IP, keep []net.IP, remove []net.IP) {
for entry := range e.entries {
entry.callback(entry, all, add, keep, remove)
@ -247,7 +265,7 @@ func (m *DnsMonitor) Remove(entry *DnsMonitorEntry) {
m.hasRemoved.Store(true)
return
}
defer m.mu.Unlock()
defer m.mu.Unlock() // +checklocksforce: only executed if the TryLock above succeeded.
e, found := m.hostnames[oldEntry.hostname]
if !found {
@ -268,15 +286,7 @@ func (m *DnsMonitor) clearRemoved() {
defer m.mu.Unlock()
for hostname, entry := range m.hostnames {
deleted := false
for e := range entry.entries {
if e.entry.Load() == nil {
delete(entry.entries, e)
deleted = true
}
}
if deleted && len(entry.entries) == 0 {
if entry.clearRemoved() {
delete(m.hostnames, hostname)
}
}

View file

@ -38,6 +38,7 @@ import (
type mockDnsLookup struct {
sync.RWMutex
// +checklocks:RWMutex
ips map[string][]net.IP
}
@ -118,14 +119,16 @@ func (r *dnsMonitorReceiverRecord) String() string {
}
var (
expectNone = &dnsMonitorReceiverRecord{}
expectNone = &dnsMonitorReceiverRecord{} // +checklocksignore: Global readonly variable.
)
type dnsMonitorReceiver struct {
sync.Mutex
t *testing.T
t *testing.T
// +checklocks:Mutex
expected *dnsMonitorReceiverRecord
// +checklocks:Mutex
received *dnsMonitorReceiverRecord
}
@ -328,13 +331,15 @@ func TestDnsMonitorNoLookupIfEmpty(t *testing.T) {
type deadlockMonitorReceiver struct {
t *testing.T
monitor *DnsMonitor
monitor *DnsMonitor // +checklocksignore: Only written to from constructor.
mu sync.RWMutex
wg sync.WaitGroup
entry *DnsMonitorEntry
started chan struct{}
// +checklocks:mu
entry *DnsMonitorEntry
started chan struct{}
// +checklocks:mu
triggered bool
closed atomic.Bool
}

View file

@ -55,8 +55,9 @@ type EtcdClientWatcher interface {
type EtcdClient struct {
compatSection string
mu sync.Mutex
client atomic.Value
mu sync.Mutex
client atomic.Value
// +checklocks:mu
listeners map[EtcdClientListener]bool
}

View file

@ -98,6 +98,7 @@ type FederationClient struct {
resumeId PrivateSessionId
hello atomic.Pointer[HelloServerMessage]
// +checklocks:helloMu
pendingMessages []*ClientMessage
closeOnLeave atomic.Bool

View file

@ -32,7 +32,7 @@ import (
"net/url"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/dlintw/goconf"
@ -59,12 +59,11 @@ type GeoLookup struct {
url string
isFile bool
client http.Client
mu sync.Mutex
lastModifiedHeader string
lastModifiedTime time.Time
lastModifiedHeader atomic.Value
lastModifiedTime atomic.Int64
reader *maxminddb.Reader
reader atomic.Pointer[maxminddb.Reader]
}
func NewGeoLookupFromUrl(url string) (*GeoLookup, error) {
@ -87,12 +86,9 @@ func NewGeoLookupFromFile(filename string) (*GeoLookup, error) {
}
func (g *GeoLookup) Close() {
g.mu.Lock()
if g.reader != nil {
g.reader.Close()
g.reader = nil
if reader := g.reader.Swap(nil); reader != nil {
reader.Close()
}
g.mu.Unlock()
}
func (g *GeoLookup) Update() error {
@ -109,7 +105,7 @@ func (g *GeoLookup) updateFile() error {
return err
}
if info.ModTime().Equal(g.lastModifiedTime) {
if info.ModTime().UnixNano() == g.lastModifiedTime.Load() {
return nil
}
@ -125,13 +121,11 @@ func (g *GeoLookup) updateFile() error {
metadata := reader.Metadata
log.Printf("Using %s GeoIP database from %s (built on %s)", metadata.DatabaseType, g.url, time.Unix(int64(metadata.BuildEpoch), 0).UTC())
g.mu.Lock()
if g.reader != nil {
g.reader.Close()
if old := g.reader.Swap(reader); old != nil {
old.Close()
}
g.reader = reader
g.lastModifiedTime = info.ModTime()
g.mu.Unlock()
g.lastModifiedTime.Store(info.ModTime().UnixNano())
return nil
}
@ -140,8 +134,8 @@ func (g *GeoLookup) updateUrl() error {
if err != nil {
return err
}
if g.lastModifiedHeader != "" {
request.Header.Add("If-Modified-Since", g.lastModifiedHeader)
if header := g.lastModifiedHeader.Load(); header != nil {
request.Header.Add("If-Modified-Since", header.(string))
}
response, err := g.client.Do(request)
if err != nil {
@ -210,13 +204,11 @@ func (g *GeoLookup) updateUrl() error {
metadata := reader.Metadata
log.Printf("Using %s GeoIP database from %s (built on %s)", metadata.DatabaseType, g.url, time.Unix(int64(metadata.BuildEpoch), 0).UTC())
g.mu.Lock()
if g.reader != nil {
g.reader.Close()
if old := g.reader.Swap(reader); old != nil {
old.Close()
}
g.reader = reader
g.lastModifiedHeader = response.Header.Get("Last-Modified")
g.mu.Unlock()
g.lastModifiedHeader.Store(response.Header.Get("Last-Modified"))
return nil
}
@ -227,14 +219,12 @@ func (g *GeoLookup) LookupCountry(ip net.IP) (string, error) {
} `maxminddb:"country"`
}
g.mu.Lock()
if g.reader == nil {
g.mu.Unlock()
reader := g.reader.Load()
if reader == nil {
return "", ErrDatabaseNotInitialized
}
err := g.reader.Lookup(ip, &record)
g.mu.Unlock()
if err != nil {
if err := reader.Lookup(ip, &record); err != nil {
return "", err
}

View file

@ -414,14 +414,18 @@ type GrpcClients struct {
mu sync.RWMutex
version string
// +checklocks:mu
clientsMap map[string]*grpcClientsList
clients []*GrpcClient
// +checklocks:mu
clients []*GrpcClient
dnsMonitor *DnsMonitor
dnsMonitor *DnsMonitor
// +checklocks:mu
dnsDiscovery bool
etcdClient *EtcdClient
targetPrefix string
etcdClient *EtcdClient // +checklocksignore: Only written to from constructor.
targetPrefix string
// +checklocks:mu
targetInformation map[string]*GrpcTargetInformationEtcd
dialOptions atomic.Value // []grpc.DialOption
creds credentials.TransportCredentials
@ -432,7 +436,7 @@ type GrpcClients struct {
selfCheckWaitGroup sync.WaitGroup
closeCtx context.Context
closeFunc context.CancelFunc
closeFunc context.CancelFunc // +checklocksignore: No locking necessary.
}
func NewGrpcClients(config *goconf.ConfigFile, etcdClient *EtcdClient, dnsMonitor *DnsMonitor, version string) (*GrpcClients, error) {
@ -755,6 +759,9 @@ func (c *GrpcClients) onLookup(entry *DnsMonitorEntry, all []net.IP, added []net
}
func (c *GrpcClients) loadTargetsEtcd(config *goconf.ConfigFile, fromReload bool, opts ...grpc.DialOption) error {
c.mu.Lock()
defer c.mu.Unlock()
if !c.etcdClient.IsConfigured() {
return fmt.Errorf("no etcd endpoints configured")
}
@ -894,6 +901,7 @@ func (c *GrpcClients) EtcdKeyDeleted(client *EtcdClient, key string, prevValue [
c.removeEtcdClientLocked(key)
}
// +checklocks:c.mu
func (c *GrpcClients) removeEtcdClientLocked(key string) {
info, found := c.targetInformation[key]
if !found {

View file

@ -26,7 +26,7 @@ import (
)
var (
statsGrpcClients = prometheus.NewGauge(prometheus.GaugeOpts{
statsGrpcClients = prometheus.NewGauge(prometheus.GaugeOpts{ // +checklocksignore: Global readonly variable.
Namespace: "signaling",
Subsystem: "grpc",
Name: "clients",

View file

@ -79,9 +79,10 @@ type HttpClientPool struct {
mu sync.Mutex
transport *http.Transport
clients map[string]*Pool
// +checklocks:mu
clients map[string]*Pool
maxConcurrentRequestsPerHost int
maxConcurrentRequestsPerHost int // +checklocksignore: Only written to from constructor.
}
func NewHttpClientPool(maxConcurrentRequestsPerHost int, skipVerify bool) (*HttpClientPool, error) {

54
hub.go
View file

@ -162,13 +162,17 @@ type Hub struct {
mu sync.RWMutex
ru sync.RWMutex
sid atomic.Uint64
clients map[uint64]HandlerClient
sid atomic.Uint64
// +checklocks:mu
clients map[uint64]HandlerClient
// +checklocks:mu
sessions map[uint64]Session
rooms map[string]*Room
// +checklocks:ru
rooms map[string]*Room
roomSessions RoomSessions
roomPing *RoomPing
roomSessions RoomSessions
roomPing *RoomPing
// +checklocks:mu
virtualSessions map[PublicSessionId]uint64
decodeCaches []*LruCache[*SessionIdData]
@ -179,12 +183,18 @@ type Hub struct {
allowSubscribeAnyStream bool
expiredSessions map[Session]time.Time
anonymousSessions map[*ClientSession]time.Time
// +checklocks:mu
expiredSessions map[Session]time.Time
// +checklocks:mu
anonymousSessions map[*ClientSession]time.Time
// +checklocks:mu
expectHelloClients map[HandlerClient]time.Time
dialoutSessions map[*ClientSession]bool
remoteSessions map[*RemoteSession]bool
federatedSessions map[*ClientSession]bool
// +checklocks:mu
dialoutSessions map[*ClientSession]bool
// +checklocks:mu
remoteSessions map[*RemoteSession]bool
// +checklocks:mu
federatedSessions map[*ClientSession]bool
backendTimeout time.Duration
backend *BackendClient
@ -748,6 +758,7 @@ func (h *Hub) CreateProxyToken(publisherId string) (string, error) {
return proxy.createToken(publisherId)
}
// +checklocks:h.mu
func (h *Hub) checkExpiredSessions(now time.Time) {
for session, expires := range h.expiredSessions {
if now.After(expires) {
@ -761,6 +772,7 @@ func (h *Hub) checkExpiredSessions(now time.Time) {
}
}
// +checklocks:h.mu
func (h *Hub) checkAnonymousSessions(now time.Time) {
for session, timeout := range h.anonymousSessions {
if now.After(timeout) {
@ -775,6 +787,7 @@ func (h *Hub) checkAnonymousSessions(now time.Time) {
}
}
// +checklocks:h.mu
func (h *Hub) checkInitialHello(now time.Time) {
for client, timeout := range h.expectHelloClients {
if now.After(timeout) {
@ -788,10 +801,11 @@ func (h *Hub) checkInitialHello(now time.Time) {
func (h *Hub) performHousekeeping(now time.Time) {
h.mu.Lock()
defer h.mu.Unlock()
h.checkExpiredSessions(now)
h.checkAnonymousSessions(now)
h.checkInitialHello(now)
h.mu.Unlock()
}
func (h *Hub) removeSession(session Session) (removed bool) {
@ -820,6 +834,7 @@ func (h *Hub) removeSession(session Session) (removed bool) {
return
}
// +checklocksread:h.mu
func (h *Hub) hasSessionsLocked(withInternal bool) bool {
if withInternal {
return len(h.sessions) > 0
@ -841,6 +856,7 @@ func (h *Hub) startWaitAnonymousSessionRoom(session *ClientSession) {
h.startWaitAnonymousSessionRoomLocked(session)
}
// +checklocks:h.mu
func (h *Hub) startWaitAnonymousSessionRoomLocked(session *ClientSession) {
if session.ClientType() == HelloClientTypeInternal {
// Internal clients don't need to join a room.
@ -1629,7 +1645,7 @@ func (h *Hub) sendRoom(session *ClientSession, message *ClientMessage, room *Roo
} else {
response.Room = &RoomServerMessage{
RoomId: room.id,
Properties: room.properties,
Properties: room.Properties(),
}
}
return session.SendMessage(response)
@ -1764,7 +1780,7 @@ func (h *Hub) processRoom(sess Session, message *ClientMessage) {
NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{
Room: &RoomServerMessage{
RoomId: room.id,
Properties: room.properties,
Properties: room.Properties(),
},
}),
))
@ -1907,7 +1923,15 @@ func (h *Hub) removeRoom(room *Room) {
h.roomPing.DeleteRoom(room.Id())
}
func (h *Hub) createRoom(id string, properties json.RawMessage, backend *Backend) (*Room, error) {
// CreateRoom creates the room with the given id and properties on the given
// backend and registers it with the hub. It acquires the hub's room lock
// (h.ru) for writing and delegates to createRoomLocked, which requires that
// lock to be held.
func (h *Hub) CreateRoom(id string, properties json.RawMessage, backend *Backend) (*Room, error) {
h.ru.Lock()
defer h.ru.Unlock()
return h.createRoomLocked(id, properties, backend)
}
// +checklocks:h.ru
func (h *Hub) createRoomLocked(id string, properties json.RawMessage, backend *Backend) (*Room, error) {
// Note the write lock must be held.
room, err := NewRoom(id, properties, h, h.events, backend)
if err != nil {
@ -1947,7 +1971,7 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro
r, found := h.rooms[internalRoomId]
if !found {
var err error
if r, err = h.createRoom(roomId, room.Room.Properties, session.Backend()); err != nil {
if r, err = h.createRoomLocked(roomId, room.Room.Properties, session.Backend()); err != nil {
h.ru.Unlock()
session.SendMessage(message.NewWrappedErrorServerMessage(err))
// The session (implicitly) left the room due to an error.

View file

@ -26,7 +26,7 @@ import (
)
var (
statsHubRoomsCurrent = prometheus.NewGaugeVec(prometheus.GaugeOpts{
statsHubRoomsCurrent = prometheus.NewGaugeVec(prometheus.GaugeOpts{ // +checklocksignore: Global readonly variable.
Namespace: "signaling",
Subsystem: "hub",
Name: "rooms",

View file

@ -466,6 +466,7 @@ func processRoomRequest(t *testing.T, w http.ResponseWriter, r *http.Request, re
var (
sessionRequestHander struct {
sync.Mutex
// +checklocks:Mutex
handlers map[*testing.T]func(*BackendClientSessionRequest)
}
)
@ -2843,8 +2844,8 @@ func TestInitialRoomPermissions(t *testing.T) {
session := hub.GetSessionByPublicId(hello.Hello.SessionId).(*ClientSession)
require.NotNil(session, "Session %s does not exist", hello.Hello.SessionId)
assert.True(session.HasPermission(PERMISSION_MAY_PUBLISH_AUDIO), "Session %s should have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_AUDIO, session.permissions)
assert.False(session.HasPermission(PERMISSION_MAY_PUBLISH_VIDEO), "Session %s should not have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_VIDEO, session.permissions)
assert.True(session.HasPermission(PERMISSION_MAY_PUBLISH_AUDIO), "Session %s should have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_AUDIO, session.GetPermissions())
assert.False(session.HasPermission(PERMISSION_MAY_PUBLISH_VIDEO), "Session %s should not have %s, got %+v", session.PublicId(), PERMISSION_MAY_PUBLISH_VIDEO, session.GetPermissions())
}
func TestJoinRoomSwitchClient(t *testing.T) {

View file

@ -238,15 +238,18 @@ type JanusGateway struct {
listener GatewayListener
// Sessions is a map of the currently active sessions to the gateway.
// +checklocks:Mutex
Sessions map[uint64]*JanusSession
// Access to the Sessions map should be synchronized with the Gateway.Lock()
// and Gateway.Unlock() methods provided by the embedded sync.Mutex.
sync.Mutex
// +checklocks:writeMu
conn *websocket.Conn
nextTransaction atomic.Uint64
transactions map[uint64]*transaction
// +checklocks:Mutex
transactions map[uint64]*transaction
closer *Closer
@ -592,6 +595,7 @@ type JanusSession struct {
Id uint64
// Handles is a map of plugin handles within this session
// +checklocks:Mutex
Handles map[uint64]*JanusHandle
// Access to the Handles map should be synchronized with the Session.Lock()

10
lru.go
View file

@ -32,10 +32,12 @@ type cacheEntry[T any] struct {
}
type LruCache[T any] struct {
size int
mu sync.Mutex
size int // +checklocksignore: Only written to from constructor.
mu sync.Mutex
// +checklocks:mu
entries *list.List
data map[string]*list.Element
// +checklocks:mu
data map[string]*list.Element
}
func NewLruCache[T any](size int) *LruCache[T] {
@ -88,6 +90,7 @@ func (c *LruCache[T]) Remove(key string) {
c.mu.Unlock()
}
// +checklocks:c.mu
func (c *LruCache[T]) removeOldestLocked() {
v := c.entries.Back()
if v != nil {
@ -101,6 +104,7 @@ func (c *LruCache[T]) RemoveOldest() {
c.mu.Unlock()
}
// +checklocks:c.mu
func (c *LruCache[T]) removeElement(e *list.Element) {
c.entries.Remove(e)
entry := e.Value.(*cacheEntry[T])

View file

@ -221,13 +221,16 @@ type mcuJanus struct {
closeChan chan struct{}
muClients sync.Mutex
clients map[clientInterface]bool
clientId atomic.Uint64
// +checklocks:muClients
clients map[clientInterface]bool
clientId atomic.Uint64
// +checklocks:mu
publishers map[StreamId]*mcuJanusPublisher
publisherCreated Notifier
publisherConnected Notifier
remotePublishers map[StreamId]*mcuJanusRemotePublisher
// +checklocks:mu
remotePublishers map[StreamId]*mcuJanusRemotePublisher
reconnectTimer *time.Timer
reconnectInterval time.Duration
@ -684,8 +687,6 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id Pu
streamType: streamType,
maxBitrate: maxBitrate,
handle: handle,
handleId: handle.Id,
closeChan: make(chan struct{}, 1),
deferred: make(chan func(), 64),
},
@ -693,6 +694,8 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id Pu
id: id,
settings: settings,
}
client.handle.Store(handle)
client.handleId.Store(handle.Id)
client.mcuJanusClient.handleEvent = client.handleEvent
client.mcuJanusClient.handleHangup = client.handleHangup
client.mcuJanusClient.handleDetached = client.handleDetached
@ -701,7 +704,7 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id Pu
client.mcuJanusClient.handleMedia = client.handleMedia
m.registerClient(client)
log.Printf("Publisher %s is using handle %d", client.id, client.handleId)
log.Printf("Publisher %s is using handle %d", client.id, handle.Id)
go client.run(handle, client.closeChan)
m.mu.Lock()
m.publishers[getStreamId(id, streamType)] = client
@ -781,13 +784,13 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ
streamType: streamType,
maxBitrate: pub.MaxBitrate(),
handle: handle,
handleId: handle.Id,
closeChan: make(chan struct{}, 1),
deferred: make(chan func(), 64),
},
publisher: publisher,
}
client.handle.Store(handle)
client.handleId.Store(handle.Id)
client.mcuJanusClient.handleEvent = client.handleEvent
client.mcuJanusClient.handleHangup = client.handleHangup
client.mcuJanusClient.handleDetached = client.handleDetached
@ -865,8 +868,6 @@ func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller Re
streamType: streamType,
maxBitrate: maxBitrate,
handle: handle,
handleId: handle.Id,
closeChan: make(chan struct{}, 1),
deferred: make(chan func(), 64),
},
@ -881,6 +882,8 @@ func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller Re
port: int(port),
rtcpPort: int(rtcp_port),
}
pub.handle.Store(handle)
pub.handleId.Store(handle.Id)
pub.mcuJanusClient.handleEvent = pub.handleEvent
pub.mcuJanusClient.handleHangup = pub.handleHangup
pub.mcuJanusClient.handleDetached = pub.handleDetached
@ -946,8 +949,6 @@ func (m *mcuJanus) NewRemoteSubscriber(ctx context.Context, listener McuListener
streamType: publisher.StreamType(),
maxBitrate: pub.MaxBitrate(),
handle: handle,
handleId: handle.Id,
closeChan: make(chan struct{}, 1),
deferred: make(chan func(), 64),
},
@ -956,6 +957,8 @@ func (m *mcuJanus) NewRemoteSubscriber(ctx context.Context, listener McuListener
}
client.remote.Store(pub)
pub.addRef()
client.handle.Store(handle)
client.handleId.Store(handle.Id)
client.mcuJanusClient.handleEvent = client.handleEvent
client.mcuJanusClient.handleHangup = client.handleHangup
client.mcuJanusClient.handleDetached = client.handleDetached

View file

@ -27,6 +27,7 @@ import (
"reflect"
"strconv"
"sync"
"sync/atomic"
"github.com/notedit/janus-go"
@ -45,8 +46,8 @@ type mcuJanusClient struct {
streamType StreamType
maxBitrate int
handle *JanusHandle
handleId uint64
handle atomic.Pointer[JanusHandle]
handleId atomic.Uint64
closeChan chan struct{}
deferred chan func()
@ -81,8 +82,7 @@ func (c *mcuJanusClient) SendMessage(ctx context.Context, message *MessageClient
}
func (c *mcuJanusClient) closeClient(ctx context.Context) bool {
if handle := c.handle; handle != nil {
c.handle = nil
if handle := c.handle.Swap(nil); handle != nil {
close(c.closeChan)
if _, err := handle.Detach(ctx); err != nil {
if e, ok := err.(*janus.ErrorMsg); !ok || e.Err.Code != JANUS_ERROR_HANDLE_NOT_FOUND {
@ -127,7 +127,7 @@ loop:
}
func (c *mcuJanusClient) sendOffer(ctx context.Context, offer api.StringMap, callback func(error, api.StringMap)) {
handle := c.handle
handle := c.handle.Load()
if handle == nil {
callback(ErrNotConnected, nil)
return
@ -149,7 +149,7 @@ func (c *mcuJanusClient) sendOffer(ctx context.Context, offer api.StringMap, cal
}
func (c *mcuJanusClient) sendAnswer(ctx context.Context, answer api.StringMap, callback func(error, api.StringMap)) {
handle := c.handle
handle := c.handle.Load()
if handle == nil {
callback(ErrNotConnected, nil)
return
@ -169,7 +169,7 @@ func (c *mcuJanusClient) sendAnswer(ctx context.Context, answer api.StringMap, c
}
func (c *mcuJanusClient) sendCandidate(ctx context.Context, candidate any, callback func(error, api.StringMap)) {
handle := c.handle
handle := c.handle.Load()
if handle == nil {
callback(ErrNotConnected, nil)
return
@ -191,7 +191,7 @@ func (c *mcuJanusClient) handleTrickle(event *TrickleMsg) {
}
func (c *mcuJanusClient) selectStream(ctx context.Context, stream *streamSelection, callback func(error, api.StringMap)) {
handle := c.handle
handle := c.handle.Load()
if handle == nil {
callback(ErrNotConnected, nil)
return

View file

@ -67,38 +67,38 @@ func (p *mcuJanusPublisher) handleEvent(event *janus.EventMsg) {
ctx := context.TODO()
switch videoroom {
case "destroyed":
log.Printf("Publisher %d: associated room has been destroyed, closing", p.handleId)
log.Printf("Publisher %d: associated room has been destroyed, closing", p.handleId.Load())
go p.Close(ctx)
case "slow_link":
// Ignore, processed through "handleSlowLink" in the general events.
default:
log.Printf("Unsupported videoroom publisher event in %d: %+v", p.handleId, event)
log.Printf("Unsupported videoroom publisher event in %d: %+v", p.handleId.Load(), event)
}
} else {
log.Printf("Unsupported publisher event in %d: %+v", p.handleId, event)
log.Printf("Unsupported publisher event in %d: %+v", p.handleId.Load(), event)
}
}
func (p *mcuJanusPublisher) handleHangup(event *janus.HangupMsg) {
log.Printf("Publisher %d received hangup (%s), closing", p.handleId, event.Reason)
log.Printf("Publisher %d received hangup (%s), closing", p.handleId.Load(), event.Reason)
go p.Close(context.Background())
}
func (p *mcuJanusPublisher) handleDetached(event *janus.DetachedMsg) {
log.Printf("Publisher %d received detached, closing", p.handleId)
log.Printf("Publisher %d received detached, closing", p.handleId.Load())
go p.Close(context.Background())
}
func (p *mcuJanusPublisher) handleConnected(event *janus.WebRTCUpMsg) {
log.Printf("Publisher %d received connected", p.handleId)
log.Printf("Publisher %d received connected", p.handleId.Load())
p.mcu.publisherConnected.Notify(string(getStreamId(p.id, p.streamType)))
}
func (p *mcuJanusPublisher) handleSlowLink(event *janus.SlowLinkMsg) {
if event.Uplink {
log.Printf("Publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost)
log.Printf("Publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost)
} else {
log.Printf("Publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost)
log.Printf("Publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost)
}
}
@ -129,18 +129,22 @@ func (p *mcuJanusPublisher) NotifyReconnected() {
return
}
p.handle = handle
p.handleId = handle.Id
if prev := p.handle.Swap(handle); prev != nil {
if _, err := prev.Detach(context.Background()); err != nil {
log.Printf("Error detaching old publisher handle %d: %s", prev.Id, err)
}
}
p.handleId.Store(handle.Id)
p.session = session
p.roomId = roomId
log.Printf("Publisher %s reconnected on handle %d", p.id, p.handleId)
log.Printf("Publisher %s reconnected on handle %d", p.id, p.handleId.Load())
}
func (p *mcuJanusPublisher) Close(ctx context.Context) {
notify := false
p.mu.Lock()
if handle := p.handle; handle != nil && p.roomId != 0 {
if handle := p.handle.Load(); handle != nil && p.roomId != 0 {
destroy_msg := api.StringMap{
"request": "destroy",
"room": p.roomId,
@ -399,6 +403,11 @@ func getPublisherRemoteId(id PublicSessionId, remoteId PublicSessionId, hostname
}
func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId PublicSessionId, hostname string, port int, rtcpPort int) error {
handle := p.handle.Load()
if handle == nil {
return ErrNotConnected
}
msg := api.StringMap{
"request": "publish_remotely",
"room": p.roomId,
@ -408,7 +417,7 @@ func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId PublicSe
"port": port,
"rtcp_port": rtcpPort,
}
response, err := p.handle.Request(ctx, msg)
response, err := handle.Request(ctx, msg)
if err != nil {
return err
}
@ -436,13 +445,18 @@ func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId PublicSe
}
func (p *mcuJanusPublisher) UnpublishRemote(ctx context.Context, remoteId PublicSessionId, hostname string, port int, rtcpPort int) error {
handle := p.handle.Load()
if handle == nil {
return ErrNotConnected
}
msg := api.StringMap{
"request": "unpublish_remotely",
"room": p.roomId,
"publisher_id": streamTypeUserIds[p.streamType],
"remote_id": getPublisherRemoteId(p.id, remoteId, hostname, port, rtcpPort),
}
response, err := p.handle.Request(ctx, msg)
response, err := handle.Request(ctx, msg)
if err != nil {
return err
}

View file

@ -63,38 +63,38 @@ func (p *mcuJanusRemotePublisher) handleEvent(event *janus.EventMsg) {
ctx := context.TODO()
switch videoroom {
case "destroyed":
log.Printf("Remote publisher %d: associated room has been destroyed, closing", p.handleId)
log.Printf("Remote publisher %d: associated room has been destroyed, closing", p.handleId.Load())
go p.Close(ctx)
case "slow_link":
// Ignore, processed through "handleSlowLink" in the general events.
default:
log.Printf("Unsupported videoroom remote publisher event in %d: %+v", p.handleId, event)
log.Printf("Unsupported videoroom remote publisher event in %d: %+v", p.handleId.Load(), event)
}
} else {
log.Printf("Unsupported remote publisher event in %d: %+v", p.handleId, event)
log.Printf("Unsupported remote publisher event in %d: %+v", p.handleId.Load(), event)
}
}
func (p *mcuJanusRemotePublisher) handleHangup(event *janus.HangupMsg) {
log.Printf("Remote publisher %d received hangup (%s), closing", p.handleId, event.Reason)
log.Printf("Remote publisher %d received hangup (%s), closing", p.handleId.Load(), event.Reason)
go p.Close(context.Background())
}
func (p *mcuJanusRemotePublisher) handleDetached(event *janus.DetachedMsg) {
log.Printf("Remote publisher %d received detached, closing", p.handleId)
log.Printf("Remote publisher %d received detached, closing", p.handleId.Load())
go p.Close(context.Background())
}
func (p *mcuJanusRemotePublisher) handleConnected(event *janus.WebRTCUpMsg) {
log.Printf("Remote publisher %d received connected", p.handleId)
log.Printf("Remote publisher %d received connected", p.handleId.Load())
p.mcu.publisherConnected.Notify(string(getStreamId(p.id, p.streamType)))
}
func (p *mcuJanusRemotePublisher) handleSlowLink(event *janus.SlowLinkMsg) {
if event.Uplink {
log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost)
log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost)
} else {
log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost)
log.Printf("Remote publisher %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost)
}
}
@ -107,12 +107,16 @@ func (p *mcuJanusRemotePublisher) NotifyReconnected() {
return
}
p.handle = handle
p.handleId = handle.Id
if prev := p.handle.Swap(handle); prev != nil {
if _, err := prev.Detach(context.Background()); err != nil {
log.Printf("Error detaching old remote publisher handle %d: %s", prev.Id, err)
}
}
p.handleId.Store(handle.Id)
p.session = session
p.roomId = roomId
log.Printf("Remote publisher %s reconnected on handle %d", p.id, p.handleId)
log.Printf("Remote publisher %s reconnected on handle %d", p.id, p.handleId.Load())
}
func (p *mcuJanusRemotePublisher) Close(ctx context.Context) {
@ -125,8 +129,10 @@ func (p *mcuJanusRemotePublisher) Close(ctx context.Context) {
}
p.mu.Lock()
if handle := p.handle; handle != nil {
response, err := p.handle.Request(ctx, api.StringMap{
defer p.mu.Unlock()
if handle := p.handle.Load(); handle != nil {
response, err := handle.Request(ctx, api.StringMap{
"request": "remove_remote_publisher",
"room": p.roomId,
"id": streamTypeUserIds[p.streamType],
@ -154,5 +160,4 @@ func (p *mcuJanusRemotePublisher) Close(ctx context.Context) {
}
p.closeClient(ctx)
p.mu.Unlock()
}

View file

@ -41,7 +41,7 @@ func (p *mcuJanusRemoteSubscriber) handleEvent(event *janus.EventMsg) {
ctx := context.TODO()
switch videoroom {
case "destroyed":
log.Printf("Remote subscriber %d: associated room has been destroyed, closing", p.handleId)
log.Printf("Remote subscriber %d: associated room has been destroyed, closing", p.handleId.Load())
go p.Close(ctx)
case "event":
// Handle renegotiations, but ignore other events like selected
@ -53,33 +53,33 @@ func (p *mcuJanusRemoteSubscriber) handleEvent(event *janus.EventMsg) {
case "slow_link":
// Ignore, processed through "handleSlowLink" in the general events.
default:
log.Printf("Unsupported videoroom event %s for remote subscriber %d: %+v", videoroom, p.handleId, event)
log.Printf("Unsupported videoroom event %s for remote subscriber %d: %+v", videoroom, p.handleId.Load(), event)
}
} else {
log.Printf("Unsupported event for remote subscriber %d: %+v", p.handleId, event)
log.Printf("Unsupported event for remote subscriber %d: %+v", p.handleId.Load(), event)
}
}
func (p *mcuJanusRemoteSubscriber) handleHangup(event *janus.HangupMsg) {
log.Printf("Remote subscriber %d received hangup (%s), closing", p.handleId, event.Reason)
log.Printf("Remote subscriber %d received hangup (%s), closing", p.handleId.Load(), event.Reason)
go p.Close(context.Background())
}
func (p *mcuJanusRemoteSubscriber) handleDetached(event *janus.DetachedMsg) {
log.Printf("Remote subscriber %d received detached, closing", p.handleId)
log.Printf("Remote subscriber %d received detached, closing", p.handleId.Load())
go p.Close(context.Background())
}
func (p *mcuJanusRemoteSubscriber) handleConnected(event *janus.WebRTCUpMsg) {
log.Printf("Remote subscriber %d received connected", p.handleId)
log.Printf("Remote subscriber %d received connected", p.handleId.Load())
p.mcu.SubscriberConnected(p.Id(), p.publisher, p.streamType)
}
func (p *mcuJanusRemoteSubscriber) handleSlowLink(event *janus.SlowLinkMsg) {
if event.Uplink {
log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost)
log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost)
} else {
log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost)
log.Printf("Remote subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost)
}
}
@ -98,12 +98,16 @@ func (p *mcuJanusRemoteSubscriber) NotifyReconnected() {
return
}
p.handle = handle
p.handleId = handle.Id
if prev := p.handle.Swap(handle); prev != nil {
if _, err := prev.Detach(context.Background()); err != nil {
log.Printf("Error detaching old remote subscriber handle %d: %s", prev.Id, err)
}
}
p.handleId.Store(handle.Id)
p.roomId = pub.roomId
p.sid = strconv.FormatUint(handle.Id, 10)
p.listener.SubscriberSidUpdated(p)
log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId)
log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId.Load())
}
func (p *mcuJanusRemoteSubscriber) Close(ctx context.Context) {

View file

@ -47,7 +47,7 @@ func (p *mcuJanusSubscriber) handleEvent(event *janus.EventMsg) {
ctx := context.TODO()
switch videoroom {
case "destroyed":
log.Printf("Subscriber %d: associated room has been destroyed, closing", p.handleId)
log.Printf("Subscriber %d: associated room has been destroyed, closing", p.handleId.Load())
go p.Close(ctx)
case "updated":
streams, ok := getPluginValue(event.Plugindata, pluginVideoRoom, "streams").([]any)
@ -64,7 +64,7 @@ func (p *mcuJanusSubscriber) handleEvent(event *janus.EventMsg) {
}
}
log.Printf("Subscriber %d: received updated event with no active media streams, closing", p.handleId)
log.Printf("Subscriber %d: received updated event with no active media streams, closing", p.handleId.Load())
go p.Close(ctx)
case "event":
// Handle renegotiations, but ignore other events like selected
@ -76,33 +76,33 @@ func (p *mcuJanusSubscriber) handleEvent(event *janus.EventMsg) {
case "slow_link":
// Ignore, processed through "handleSlowLink" in the general events.
default:
log.Printf("Unsupported videoroom event %s for subscriber %d: %+v", videoroom, p.handleId, event)
log.Printf("Unsupported videoroom event %s for subscriber %d: %+v", videoroom, p.handleId.Load(), event)
}
} else {
log.Printf("Unsupported event for subscriber %d: %+v", p.handleId, event)
log.Printf("Unsupported event for subscriber %d: %+v", p.handleId.Load(), event)
}
}
func (p *mcuJanusSubscriber) handleHangup(event *janus.HangupMsg) {
log.Printf("Subscriber %d received hangup (%s), closing", p.handleId, event.Reason)
log.Printf("Subscriber %d received hangup (%s), closing", p.handleId.Load(), event.Reason)
go p.Close(context.Background())
}
func (p *mcuJanusSubscriber) handleDetached(event *janus.DetachedMsg) {
log.Printf("Subscriber %d received detached, closing", p.handleId)
log.Printf("Subscriber %d received detached, closing", p.handleId.Load())
go p.Close(context.Background())
}
func (p *mcuJanusSubscriber) handleConnected(event *janus.WebRTCUpMsg) {
log.Printf("Subscriber %d received connected", p.handleId)
log.Printf("Subscriber %d received connected", p.handleId.Load())
p.mcu.SubscriberConnected(p.Id(), p.publisher, p.streamType)
}
func (p *mcuJanusSubscriber) handleSlowLink(event *janus.SlowLinkMsg) {
if event.Uplink {
log.Printf("Subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId, event.Lost)
log.Printf("Subscriber %s (%d) is reporting %d lost packets on the uplink (Janus -> client)", p.listener.PublicId(), p.handleId.Load(), event.Lost)
} else {
log.Printf("Subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId, event.Lost)
log.Printf("Subscriber %s (%d) is reporting %d lost packets on the downlink (client -> Janus)", p.listener.PublicId(), p.handleId.Load(), event.Lost)
}
}
@ -121,12 +121,16 @@ func (p *mcuJanusSubscriber) NotifyReconnected() {
return
}
p.handle = handle
p.handleId = handle.Id
if prev := p.handle.Swap(handle); prev != nil {
if _, err := prev.Detach(context.Background()); err != nil {
log.Printf("Error detaching old subscriber handle %d: %s", prev.Id, err)
}
}
p.handleId.Store(handle.Id)
p.roomId = pub.roomId
p.sid = strconv.FormatUint(handle.Id, 10)
p.listener.SubscriberSidUpdated(p)
log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId)
log.Printf("Subscriber %d for publisher %s reconnected on handle %d", p.id, p.publisher, p.handleId.Load())
}
func (p *mcuJanusSubscriber) closeClient(ctx context.Context) bool {
@ -152,7 +156,7 @@ func (p *mcuJanusSubscriber) Close(ctx context.Context) {
}
func (p *mcuJanusSubscriber) joinRoom(ctx context.Context, stream *streamSelection, callback func(error, api.StringMap)) {
handle := p.handle
handle := p.handle.Load()
if handle == nil {
callback(ErrNotConnected, nil)
return
@ -210,15 +214,19 @@ retry:
return
}
p.handle = handle
p.handleId = handle.Id
if prev := p.handle.Swap(handle); prev != nil {
if _, err := prev.Detach(context.Background()); err != nil {
log.Printf("Error detaching old subscriber handle %d: %s", prev.Id, err)
}
}
p.handleId.Store(handle.Id)
p.roomId = pub.roomId
p.sid = strconv.FormatUint(handle.Id, 10)
p.listener.SubscriberSidUpdated(p)
p.closeChan = make(chan struct{}, 1)
statsSubscribersCurrent.WithLabelValues(string(p.streamType)).Inc()
go p.run(p.handle, p.closeChan)
log.Printf("Already connected subscriber %d for %s, leaving and re-joining on handle %d", p.id, p.streamType, p.handleId)
go p.run(handle, p.closeChan)
log.Printf("Already connected subscriber %d for %s, leaving and re-joining on handle %d", p.id, p.streamType, p.handleId.Load())
goto retry
case JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM:
fallthrough
@ -258,7 +266,7 @@ retry:
}
func (p *mcuJanusSubscriber) update(ctx context.Context, stream *streamSelection, callback func(error, api.StringMap)) {
handle := p.handle
handle := p.handle.Load()
if handle == nil {
callback(ErrNotConnected, nil)
return

View file

@ -58,19 +58,27 @@ type TestJanusGateway struct {
sid atomic.Uint64
tid atomic.Uint64
hid atomic.Uint64
rid atomic.Uint64
hid atomic.Uint64 // +checklocksignore: Atomic
rid atomic.Uint64 // +checklocksignore: Atomic
mu sync.Mutex
sessions map[uint64]*JanusSession
// +checklocks:mu
sessions map[uint64]*JanusSession
// +checklocks:mu
transactions map[uint64]*transaction
handles map[uint64]*TestJanusHandle
rooms map[uint64]*TestJanusRoom
handlers map[string]TestJanusHandler
// +checklocks:mu
handles map[uint64]*TestJanusHandle
// +checklocks:mu
rooms map[uint64]*TestJanusRoom
// +checklocks:mu
handlers map[string]TestJanusHandler
attachCount atomic.Int32
joinCount atomic.Int32
// +checklocks:mu
attachCount int
// +checklocks:mu
joinCount int
// +checklocks:mu
handleRooms map[*TestJanusHandle]*TestJanusRoom
}
@ -142,6 +150,7 @@ func (g *TestJanusGateway) Close() error {
return nil
}
// +checklocks:g.mu
func (g *TestJanusGateway) processMessage(session *JanusSession, handle *TestJanusHandle, body api.StringMap, jsep api.StringMap) any {
request := body["request"].(string)
switch request {
@ -165,15 +174,18 @@ func (g *TestJanusGateway) processMessage(session *JanusSession, handle *TestJan
error_code := JANUS_OK
if body["ptype"] == "subscriber" {
if strings.Contains(g.t.Name(), "NoSuchRoom") {
if g.joinCount.Add(1) == 1 {
g.joinCount++
if g.joinCount == 1 {
error_code = JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM
}
} else if strings.Contains(g.t.Name(), "AlreadyJoined") {
if g.joinCount.Add(1) == 1 {
g.joinCount++
if g.joinCount == 1 {
error_code = JANUS_VIDEOROOM_ERROR_ALREADY_JOINED
}
} else if strings.Contains(g.t.Name(), "SubscriberTimeout") {
if g.joinCount.Add(1) == 1 {
g.joinCount++
if g.joinCount == 1 {
error_code = JANUS_VIDEOROOM_ERROR_NO_SUCH_FEED
}
}
@ -354,7 +366,8 @@ func (g *TestJanusGateway) processRequest(msg api.StringMap) any {
switch method {
case "attach":
if strings.Contains(g.t.Name(), "AlreadyJoinedAttachError") {
if g.attachCount.Add(1) == 4 {
g.attachCount++
if g.attachCount == 4 {
return &janus.ErrorMsg{
Err: janus.ErrorData{
Code: JANUS_ERROR_UNKNOWN,

View file

@ -351,10 +351,11 @@ type mcuProxyConnection struct {
closer *Closer
closedDone *Closer
closed atomic.Bool
conn *websocket.Conn
// +checklocks:mu
conn *websocket.Conn
helloProcessed atomic.Bool
connectedSince time.Time
connectedSince atomic.Int64
reconnectTimer *time.Timer
reconnectInterval atomic.Int64
shutdownScheduled atomic.Bool
@ -371,15 +372,20 @@ type mcuProxyConnection struct {
version atomic.Value
features atomic.Value
callbacks map[string]mcuProxyCallback
// +checklocks:mu
callbacks map[string]mcuProxyCallback
// +checklocks:mu
deferredCallbacks map[string]mcuProxyCallback
publishersLock sync.RWMutex
publishers map[string]*mcuProxyPublisher
publisherIds map[StreamId]PublicSessionId
// +checklocks:publishersLock
publishers map[string]*mcuProxyPublisher
// +checklocks:publishersLock
publisherIds map[StreamId]PublicSessionId
subscribersLock sync.RWMutex
subscribers map[string]*mcuProxySubscriber
// +checklocks:subscribersLock
subscribers map[string]*mcuProxySubscriber
}
func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP, token string) (*mcuProxyConnection, error) {
@ -486,7 +492,10 @@ func (c *mcuProxyConnection) GetStats() *mcuProxyConnectionStats {
c.mu.Lock()
if c.conn != nil {
result.Connected = true
result.Uptime = &c.connectedSince
if since := c.connectedSince.Load(); since != 0 {
t := time.UnixMicro(since)
result.Uptime = &t
}
load := c.Load()
result.Load = &load
shutdown := c.IsShutdownScheduled()
@ -705,6 +714,7 @@ func (c *mcuProxyConnection) close() {
if c.conn != nil {
c.conn.Close()
c.conn = nil
c.connectedSince.Store(0)
if c.trackClose.CompareAndSwap(true, false) {
statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Dec()
}
@ -810,9 +820,9 @@ func (c *mcuProxyConnection) reconnect() {
log.Printf("Connected to %s", c)
c.closed.Store(false)
c.helloProcessed.Store(false)
c.connectedSince.Store(time.Now().UnixMicro())
c.mu.Lock()
c.connectedSince = time.Now()
c.conn = conn
c.mu.Unlock()
@ -1157,6 +1167,7 @@ func (c *mcuProxyConnection) sendMessage(msg *ProxyClientMessage) error {
return c.sendMessageLocked(msg)
}
// +checklocks:c.mu
func (c *mcuProxyConnection) sendMessageLocked(msg *ProxyClientMessage) error {
if proxyDebugMessages {
log.Printf("Send message to %s: %+v", c, msg)
@ -1449,16 +1460,19 @@ type mcuProxy struct {
tokenKey *rsa.PrivateKey
config ProxyConfig
dialer *websocket.Dialer
connections []*mcuProxyConnection
dialer *websocket.Dialer
connectionsMu sync.RWMutex
// +checklocks:connectionsMu
connections []*mcuProxyConnection
// +checklocks:connectionsMu
connectionsMap map[string][]*mcuProxyConnection
connectionsMu sync.RWMutex
connRequests atomic.Int64
nextSort atomic.Int64
settings McuSettings
mu sync.RWMutex
mu sync.RWMutex
// +checklocks:mu
publishers map[StreamId]*mcuProxyConnection
publisherWaiters ChannelWaiters
@ -1615,6 +1629,13 @@ func (m *mcuProxy) createToken(subject string) (string, error) {
return tokenString, nil
}
// getConnections returns a snapshot of the current proxy connections.
//
// The slice header is copied while holding "connectionsMu" so callers can
// iterate the result without the lock. Returning a copy (instead of the
// internal slice as before) prevents a data race: an append to
// "m.connections" performed under the lock may write into the same backing
// array an escaped slice header still aliases.
func (m *mcuProxy) getConnections() []*mcuProxyConnection {
	m.connectionsMu.RLock()
	defer m.connectionsMu.RUnlock()
	result := make([]*mcuProxyConnection, len(m.connections))
	copy(result, m.connections)
	return result
}
func (m *mcuProxy) hasConnections() bool {
m.connectionsMu.RLock()
defer m.connectionsMu.RUnlock()
@ -1835,7 +1856,10 @@ func (m *mcuProxy) GetServerInfoSfu() *BackendServerInfoSfu {
if c.IsConnected() {
proxy.Connected = true
proxy.Shutdown = internal.MakePtr(c.IsShutdownScheduled())
proxy.Uptime = &c.connectedSince
if since := c.connectedSince.Load(); since != 0 {
t := time.UnixMicro(since)
proxy.Uptime = &t
}
proxy.Version = c.Version()
proxy.Features = c.Features()
proxy.Country = c.Country()

View file

@ -204,7 +204,8 @@ type testProxyServerSubscriber struct {
type testProxyServerClient struct {
t *testing.T
server *TestProxyServerHandler
server *TestProxyServerHandler
// +checklocks:mu
ws *websocket.Conn
processMessage proxyServerClientHandler
@ -549,12 +550,15 @@ type TestProxyServerHandler struct {
upgrader *websocket.Upgrader
country string
mu sync.Mutex
load atomic.Int64
incoming atomic.Pointer[float64]
outgoing atomic.Pointer[float64]
clients map[PublicSessionId]*testProxyServerClient
publishers map[PublicSessionId]*testProxyServerPublisher
mu sync.Mutex
load atomic.Int64
incoming atomic.Pointer[float64]
outgoing atomic.Pointer[float64]
// +checklocks:mu
clients map[PublicSessionId]*testProxyServerClient
// +checklocks:mu
publishers map[PublicSessionId]*testProxyServerPublisher
// +checklocks:mu
subscribers map[string]*testProxyServerSubscriber
wakeupChan chan struct{}
@ -1039,8 +1043,8 @@ func Test_ProxyAddRemoveConnectionsDnsDiscovery(t *testing.T) {
},
}, 0)
if assert.NotNil(mcu.connections[0].ip) {
assert.True(ip1.Equal(mcu.connections[0].ip), "ip addresses differ: expected %s, got %s", ip1.String(), mcu.connections[0].ip.String())
if connections := mcu.getConnections(); assert.Len(connections, 1) && assert.NotNil(connections[0].ip) {
assert.True(ip1.Equal(connections[0].ip), "ip addresses differ: expected %s, got %s", ip1.String(), connections[0].ip.String())
}
dnsMonitor := mcu.config.(*proxyConfigStatic).dnsMonitor
@ -1744,8 +1748,9 @@ func Test_ProxySubscriberBandwidthOverload(t *testing.T) {
}
type mockGrpcServerHub struct {
proxy atomic.Pointer[mcuProxy]
sessionsLock sync.Mutex
proxy atomic.Pointer[mcuProxy]
sessionsLock sync.Mutex
// +checklocks:sessionsLock
sessionByPublicId map[PublicSessionId]Session
}

View file

@ -41,8 +41,10 @@ const (
)
type TestMCU struct {
mu sync.Mutex
publishers map[PublicSessionId]*TestMCUPublisher
mu sync.Mutex
// +checklocks:mu
publishers map[PublicSessionId]*TestMCUPublisher
// +checklocks:mu
subscribers map[string]*TestMCUSubscriber
}

View file

@ -32,10 +32,13 @@ import (
)
type LoopbackNatsClient struct {
mu sync.Mutex
mu sync.Mutex
// +checklocks:mu
subscriptions map[string]map[*loopbackNatsSubscription]bool
wakeup sync.Cond
// +checklocks:mu
wakeup sync.Cond
// +checklocks:mu
incoming list.List
}
@ -65,6 +68,7 @@ func (c *LoopbackNatsClient) processMessages() {
}
}
// +checklocks:c.mu
func (c *LoopbackNatsClient) processMessage(msg *nats.Msg) {
subs, found := c.subscriptions[msg.Subject]
if !found {

View file

@ -39,7 +39,9 @@ func (w *Waiter) Wait(ctx context.Context) error {
type Notifier struct {
sync.Mutex
waiters map[string]*Waiter
// +checklocks:Mutex
waiters map[string]*Waiter
// +checklocks:Mutex
waiterMap map[string]map[*Waiter]bool
}

View file

@ -58,13 +58,14 @@ const (
)
var (
ErrNotConnected = errors.New("not connected")
ErrNotConnected = errors.New("not connected") // +checklocksignore: Global readonly variable.
)
type RemoteConnection struct {
mu sync.Mutex
p *ProxyServer
url *url.URL
mu sync.Mutex
p *ProxyServer
url *url.URL
// +checklocks:mu
conn *websocket.Conn
closer *signaling.Closer
closed atomic.Bool
@ -73,15 +74,20 @@ type RemoteConnection struct {
tokenKey *rsa.PrivateKey
tlsConfig *tls.Config
// +checklocks:mu
connectedSince time.Time
reconnectTimer *time.Timer
reconnectInterval atomic.Int64
msgId atomic.Int64
msgId atomic.Int64
// +checklocks:mu
helloMsgId string
sessionId signaling.PublicSessionId
// +checklocks:mu
sessionId signaling.PublicSessionId
pendingMessages []*signaling.ProxyClientMessage
// +checklocks:mu
pendingMessages []*signaling.ProxyClientMessage
// +checklocks:mu
messageCallbacks map[string]chan *signaling.ProxyServerMessage
}
@ -151,20 +157,35 @@ func (c *RemoteConnection) reconnect() {
c.reconnectInterval.Store(int64(initialReconnectInterval))
if err := c.sendHello(); err != nil {
log.Printf("Error sending hello request to proxy at %s: %s", c, err)
if !c.sendReconnectHello() || !c.sendPing() {
c.scheduleReconnect()
return
}
if !c.sendPing() {
return
}
go c.readPump(conn)
}
// sendReconnectHello sends a "hello" request to the proxy after a reconnect
// and reports whether it could be sent. The connection mutex is held for the
// duration of the call, as required by sendHello.
func (c *RemoteConnection) sendReconnectHello() bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	err := c.sendHello(context.Background())
	if err != nil {
		log.Printf("Error sending hello request to proxy at %s: %s", c, err)
	}
	return err == nil
}
// scheduleReconnect acquires the connection mutex and schedules a reconnect
// of the proxy connection via scheduleReconnectLocked (which requires c.mu
// to be held, see its "+checklocks:c.mu" annotation).
func (c *RemoteConnection) scheduleReconnect() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.scheduleReconnectLocked()
}
// +checklocks:c.mu
func (c *RemoteConnection) scheduleReconnectLocked() {
if err := c.sendClose(); err != nil && err != ErrNotConnected {
log.Printf("Could not send close message to %s: %s", c, err)
}
@ -180,7 +201,8 @@ func (c *RemoteConnection) scheduleReconnect() {
c.reconnectInterval.Store(interval)
}
func (c *RemoteConnection) sendHello() error {
// +checklocks:c.mu
func (c *RemoteConnection) sendHello(ctx context.Context) error {
c.helloMsgId = strconv.FormatInt(c.msgId.Add(1), 10)
msg := &signaling.ProxyClientMessage{
Id: c.helloMsgId,
@ -200,7 +222,7 @@ func (c *RemoteConnection) sendHello() error {
msg.Hello.Token = tokenString
}
return c.SendMessage(msg)
return c.sendMessageLocked(ctx, msg)
}
func (c *RemoteConnection) sendClose() error {
@ -210,6 +232,7 @@ func (c *RemoteConnection) sendClose() error {
return c.sendCloseLocked()
}
// +checklocks:c.mu
func (c *RemoteConnection) sendCloseLocked() error {
if c.conn == nil {
return ErrNotConnected
@ -276,6 +299,7 @@ func (c *RemoteConnection) SendMessage(msg *signaling.ProxyClientMessage) error
return c.sendMessageLocked(context.Background(), msg)
}
// +checklocks:c.mu
func (c *RemoteConnection) deferMessage(ctx context.Context, msg *signaling.ProxyClientMessage) {
c.pendingMessages = append(c.pendingMessages, msg)
if ctx.Done() != nil {
@ -294,6 +318,7 @@ func (c *RemoteConnection) deferMessage(ctx context.Context, msg *signaling.Prox
}
}
// +checklocks:c.mu
func (c *RemoteConnection) sendMessageLocked(ctx context.Context, msg *signaling.ProxyClientMessage) error {
if c.conn == nil {
// Defer until connected.
@ -397,21 +422,24 @@ func (c *RemoteConnection) writePump() {
}
func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) {
c.mu.Lock()
defer c.mu.Unlock()
c.helloMsgId = ""
switch msg.Type {
case "error":
if msg.Error.Code == "no_such_session" {
log.Printf("Session %s could not be resumed on %s, registering new", c.sessionId, c)
c.sessionId = ""
if err := c.sendHello(); err != nil {
if err := c.sendHello(context.Background()); err != nil {
log.Printf("Could not send hello request to %s: %s", c, err)
c.scheduleReconnect()
c.scheduleReconnectLocked()
}
return
}
log.Printf("Hello connection to %s failed with %+v, reconnecting", c, msg.Error)
c.scheduleReconnect()
c.scheduleReconnectLocked()
case "hello":
resumed := c.sessionId == msg.Hello.SessionId
c.sessionId = msg.Hello.SessionId
@ -443,21 +471,32 @@ func (c *RemoteConnection) processHello(msg *signaling.ProxyServerMessage) {
}
default:
log.Printf("Received unsupported hello response %+v from %s, reconnecting", msg, c)
c.scheduleReconnect()
c.scheduleReconnectLocked()
}
}
func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) {
if msg.Id != "" {
c.mu.Lock()
ch, found := c.messageCallbacks[msg.Id]
if found {
delete(c.messageCallbacks, msg.Id)
c.mu.Unlock()
ch <- msg
return
}
func (c *RemoteConnection) handleCallback(msg *signaling.ProxyServerMessage) bool {
if msg.Id == "" {
return false
}
c.mu.Lock()
ch, found := c.messageCallbacks[msg.Id]
if !found {
c.mu.Unlock()
return false
}
delete(c.messageCallbacks, msg.Id)
c.mu.Unlock()
ch <- msg
return true
}
func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) {
if c.handleCallback(msg) {
return
}
switch msg.Type {
@ -467,7 +506,9 @@ func (c *RemoteConnection) processMessage(msg *signaling.ProxyServerMessage) {
log.Printf("Connection to %s was closed: %s", c, msg.Bye.Reason)
if msg.Bye.Reason == "session_expired" {
// Don't try to resume expired session.
c.mu.Lock()
c.sessionId = ""
c.mu.Unlock()
}
c.scheduleReconnect()
default:
@ -487,21 +528,31 @@ func (c *RemoteConnection) processEvent(msg *signaling.ProxyServerMessage) {
}
}
func (c *RemoteConnection) RequestMessage(ctx context.Context, msg *signaling.ProxyClientMessage) (*signaling.ProxyServerMessage, error) {
func (c *RemoteConnection) sendMessageWithCallbackLocked(ctx context.Context, msg *signaling.ProxyClientMessage) (string, <-chan *signaling.ProxyServerMessage, error) {
msg.Id = strconv.FormatInt(c.msgId.Add(1), 10)
c.mu.Lock()
defer c.mu.Unlock()
if err := c.sendMessageLocked(ctx, msg); err != nil {
return nil, err
msg.Id = ""
return "", nil, err
}
ch := make(chan *signaling.ProxyServerMessage, 1)
c.messageCallbacks[msg.Id] = ch
c.mu.Unlock()
return msg.Id, ch, nil
}
func (c *RemoteConnection) RequestMessage(ctx context.Context, msg *signaling.ProxyClientMessage) (*signaling.ProxyServerMessage, error) {
id, ch, err := c.sendMessageWithCallbackLocked(ctx, msg)
if err != nil {
return nil, err
}
defer func() {
c.mu.Lock()
delete(c.messageCallbacks, msg.Id)
defer c.mu.Unlock()
delete(c.messageCallbacks, id)
}()
select {

View file

@ -133,20 +133,25 @@ type ProxyServer struct {
sid atomic.Uint64
cookie *signaling.SessionIdCodec
sessions map[uint64]*ProxySession
sessionsLock sync.RWMutex
// +checklocks:sessionsLock
sessions map[uint64]*ProxySession
clients map[string]signaling.McuClient
clientIds map[string]string
clientsLock sync.RWMutex
// +checklocks:clientsLock
clients map[string]signaling.McuClient
// +checklocks:clientsLock
clientIds map[string]string
tokenId string
tokenKey *rsa.PrivateKey
remoteTlsConfig *tls.Config
remoteTlsConfig *tls.Config // +checklocksignore: Only written to from constructor.
remoteHostname string
remoteConnections map[string]*RemoteConnection
remoteConnectionsLock sync.Mutex
remotePublishers map[string]map[*proxyRemotePublisher]bool
// +checklocks:remoteConnectionsLock
remoteConnections map[string]*RemoteConnection
// +checklocks:remoteConnectionsLock
remotePublishers map[string]map[*proxyRemotePublisher]bool
}
func IsPublicIP(IP net.IP) bool {
@ -1503,6 +1508,7 @@ func (s *ProxyServer) DeleteSession(id uint64) {
s.deleteSessionLocked(id)
}
// +checklocks:s.sessionsLock
func (s *ProxyServer) deleteSessionLocked(id uint64) {
if session, found := s.sessions[id]; found {
delete(s.sessions, id)

View file

@ -89,7 +89,7 @@ func WaitForProxyServer(ctx context.Context, t *testing.T, proxy *ProxyServer) {
case <-ctx.Done():
proxy.clientsLock.Lock()
proxy.remoteConnectionsLock.Lock()
assert.Fail(t, "Error waiting for proxy to terminate", "clients %+v / sessions %+v / remoteConnections %+v: %+v", proxy.clients, proxy.sessions, proxy.remoteConnections, ctx.Err())
assert.Fail(t, "Error waiting for proxy to terminate", "clients %+v / sessions %+v / remoteConnections %+v: %+v", clients, sessions, remoteConnections, ctx.Err())
proxy.remoteConnectionsLock.Unlock()
proxy.clientsLock.Unlock()
return
@ -936,10 +936,12 @@ func NewUnpublishRemoteTestMCU(t *testing.T) *UnpublishRemoteTestMCU {
type UnpublishRemoteTestPublisher struct {
TestMCUPublisher
t *testing.T
t *testing.T // +checklocksignore: Only written to from constructor.
mu sync.RWMutex
remoteId signaling.PublicSessionId
mu sync.RWMutex
// +checklocks:mu
remoteId signaling.PublicSessionId
// +checklocks:mu
remoteData *remotePublisherData
}

View file

@ -53,20 +53,27 @@ type ProxySession struct {
ctx context.Context
closeFunc context.CancelFunc
clientLock sync.Mutex
client *ProxyClient
clientLock sync.Mutex
// +checklocks:clientLock
client *ProxyClient
// +checklocks:clientLock
pendingMessages []*signaling.ProxyServerMessage
publishersLock sync.Mutex
publishers map[string]signaling.McuPublisher
publisherIds map[signaling.McuPublisher]string
// +checklocks:publishersLock
publishers map[string]signaling.McuPublisher
// +checklocks:publishersLock
publisherIds map[signaling.McuPublisher]string
subscribersLock sync.Mutex
subscribers map[string]signaling.McuSubscriber
subscriberIds map[signaling.McuSubscriber]string
// +checklocks:subscribersLock
subscribers map[string]signaling.McuSubscriber
// +checklocks:subscribersLock
subscriberIds map[signaling.McuSubscriber]string
remotePublishersLock sync.Mutex
remotePublishers map[signaling.McuRemoteAwarePublisher]map[string]*remotePublisherData
// +checklocks:remotePublishersLock
remotePublishers map[signaling.McuRemoteAwarePublisher]map[string]*remotePublisherData
}
func NewProxySession(proxy *ProxyServer, sid uint64, id signaling.PublicSessionId) *ProxySession {
@ -301,6 +308,8 @@ func (s *ProxySession) DeletePublisher(publisher signaling.McuPublisher) string
delete(s.publishers, id)
delete(s.publisherIds, publisher)
if rp, ok := publisher.(signaling.McuRemoteAwarePublisher); ok {
s.remotePublishersLock.Lock()
defer s.remotePublishersLock.Unlock()
delete(s.remotePublishers, rp)
}
go s.proxy.PublisherDeleted(publisher)
@ -363,8 +372,8 @@ func (s *ProxySession) clearRemotePublishers() {
}
func (s *ProxySession) clearSubscribers() {
s.publishersLock.Lock()
defer s.publishersLock.Unlock()
s.subscribersLock.Lock()
defer s.subscribersLock.Unlock()
go func(subscribers map[string]signaling.McuSubscriber) {
for id, subscriber := range subscribers {

View file

@ -43,10 +43,11 @@ var (
type ProxyTestClient struct {
t *testing.T
assert *assert.Assertions
assert *assert.Assertions // +checklocksignore: Only written to from constructor.
require *require.Assertions
mu sync.Mutex
mu sync.Mutex
// +checklocks:mu
conn *websocket.Conn
messageChan chan []byte
readErrorChan chan error

View file

@ -35,12 +35,14 @@ import (
type proxyConfigEtcd struct {
mu sync.Mutex
proxy McuProxy
proxy McuProxy // +checklocksignore: Only written to from constructor.
client *EtcdClient
keyPrefix string
keyInfos map[string]*ProxyInformationEtcd
urlToKey map[string]string
// +checklocks:mu
keyInfos map[string]*ProxyInformationEtcd
// +checklocks:mu
urlToKey map[string]string
closeCtx context.Context
closeFunc context.CancelFunc
@ -211,6 +213,7 @@ func (p *proxyConfigEtcd) EtcdKeyDeleted(client *EtcdClient, key string, prevVal
p.removeEtcdProxyLocked(key)
}
// +checklocks:p.mu
func (p *proxyConfigEtcd) removeEtcdProxyLocked(key string) {
info, found := p.keyInfos[key]
if !found {

View file

@ -43,9 +43,11 @@ type proxyConfigStatic struct {
mu sync.Mutex
proxy McuProxy
dnsMonitor *DnsMonitor
dnsMonitor *DnsMonitor
// +checklocks:mu
dnsDiscovery bool
// +checklocks:mu
connectionsMap map[string]*ipList
}
@ -107,7 +109,7 @@ func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool
}
if dnsDiscovery {
p.connectionsMap[u] = &ipList{
p.connectionsMap[u] = &ipList{ // +checklocksignore: Not supported for iter loops yet, see https://github.com/google/gvisor/issues/12176
hostname: parsed.Host,
}
continue
@ -124,7 +126,7 @@ func (p *proxyConfigStatic) configure(config *goconf.ConfigFile, fromReload bool
}
}
p.connectionsMap[u] = &ipList{
p.connectionsMap[u] = &ipList{ // +checklocksignore: Not supported for iter loops yet, see https://github.com/google/gvisor/issues/12176
hostname: parsed.Host,
}
}

View file

@ -52,10 +52,12 @@ type proxyConfigEvent struct {
}
type mcuProxyForConfig struct {
t *testing.T
t *testing.T
mu sync.Mutex
// +checklocks:mu
expected []proxyConfigEvent
mu sync.Mutex
waiters []chan struct{}
// +checklocks:mu
waiters []chan struct{}
}
func newMcuProxyForConfig(t *testing.T) *mcuProxyForConfig {
@ -63,6 +65,8 @@ func newMcuProxyForConfig(t *testing.T) *mcuProxyForConfig {
t: t,
}
t.Cleanup(func() {
proxy.mu.Lock()
defer proxy.mu.Unlock()
assert.Empty(t, proxy.expected)
})
return proxy
@ -83,20 +87,29 @@ func (p *mcuProxyForConfig) Expect(action string, url string, ips ...net.IP) {
})
}
func (p *mcuProxyForConfig) WaitForEvents(ctx context.Context) {
func (p *mcuProxyForConfig) addWaiter() chan struct{} {
p.t.Helper()
p.mu.Lock()
defer p.mu.Unlock()
if len(p.expected) == 0 {
return
return nil
}
waiter := make(chan struct{})
p.waiters = append(p.waiters, waiter)
p.mu.Unlock()
defer p.mu.Lock()
return waiter
}
func (p *mcuProxyForConfig) WaitForEvents(ctx context.Context) {
p.t.Helper()
waiter := p.addWaiter()
if waiter == nil {
return
}
select {
case <-ctx.Done():
assert.NoError(p.t, ctx.Err())
@ -104,6 +117,32 @@ func (p *mcuProxyForConfig) WaitForEvents(ctx context.Context) {
}
}
func (p *mcuProxyForConfig) getWaitersIfEmpty() []chan struct{} {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.expected) != 0 {
return nil
}
waiters := p.waiters
p.waiters = nil
return waiters
}
func (p *mcuProxyForConfig) getExpectedEvent() *proxyConfigEvent {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.expected) == 0 {
return nil
}
expected := p.expected[0]
p.expected = p.expected[1:]
return &expected
}
func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) {
p.t.Helper()
pc := make([]uintptr, 32)
@ -121,32 +160,24 @@ func (p *mcuProxyForConfig) checkEvent(event *proxyConfigEvent) {
}
}
p.mu.Lock()
defer p.mu.Unlock()
if len(p.expected) == 0 {
expected := p.getExpectedEvent()
if expected == nil {
assert.Fail(p.t, "no event expected", "received %+v from %s:%d", event, caller.File, caller.Line)
return
}
defer func() {
if len(p.expected) == 0 {
waiters := p.waiters
p.waiters = nil
p.mu.Unlock()
defer p.mu.Lock()
for _, ch := range waiters {
ch <- struct{}{}
}
}
}()
expected := p.expected[0]
p.expected = p.expected[1:]
if !reflect.DeepEqual(expected, *event) {
if !reflect.DeepEqual(expected, event) {
assert.Fail(p.t, "wrong event", "expected %+v, received %+v from %s:%d", expected, event, caller.File, caller.Line)
}
waiters := p.getWaitersIfEmpty()
if len(waiters) == 0 {
return
}
for _, ch := range waiters {
ch <- struct{}{}
}
}
func (p *mcuProxyForConfig) AddConnection(ignoreErrors bool, url string, ips ...net.IP) error {

View file

@ -28,7 +28,9 @@ import (
type publisherStatsCounter struct {
mu sync.Mutex
// +checklocks:mu
streamTypes map[StreamType]bool
// +checklocks:mu
subscribers map[string]bool
}

19
room.go
View file

@ -69,17 +69,23 @@ type Room struct {
events AsyncEvents
backend *Backend
// +checklocks:mu
properties json.RawMessage
closer *Closer
mu *sync.RWMutex
closer *Closer
mu *sync.RWMutex
// +checklocks:mu
sessions map[PublicSessionId]Session
// +checklocks:mu
internalSessions map[*ClientSession]bool
virtualSessions map[*VirtualSession]bool
inCallSessions map[Session]bool
roomSessionData map[PublicSessionId]*RoomSessionData
// +checklocks:mu
virtualSessions map[*VirtualSession]bool
// +checklocks:mu
inCallSessions map[Session]bool
// +checklocks:mu
roomSessionData map[PublicSessionId]*RoomSessionData
// +checklocks:mu
statsRoomSessionsCurrent *prometheus.GaugeVec
// Users currently in the room
@ -590,6 +596,7 @@ func (r *Room) PublishSessionLeft(session Session) {
}
}
// +checklocksread:r.mu
func (r *Room) getClusteredInternalSessionsRLocked() (internal map[PublicSessionId]*InternalSessionData, virtual map[PublicSessionId]*VirtualSessionData) {
if r.hub.rpcClients == nil {
return nil, nil

View file

@ -70,6 +70,7 @@ type RoomPing struct {
backend *BackendClient
capabilities *Capabilities
// +checklocks:mu
entries map[string]*pingEntries
}

View file

@ -30,9 +30,11 @@ import (
)
type BuiltinRoomSessions struct {
mu sync.RWMutex
// +checklocks:mu
sessionIdToRoomSession map[PublicSessionId]RoomSessionId
// +checklocks:mu
roomSessionToSessionid map[RoomSessionId]PublicSessionId
mu sync.RWMutex
clients *GrpcClients
}

View file

@ -94,7 +94,8 @@ func createTLSListener(addr string, certFile, keyFile string) (net.Listener, err
}
type Listeners struct {
mu sync.Mutex
mu sync.Mutex
// +checklocks:mu
listeners []net.Listener
}

View file

@ -44,7 +44,7 @@ var (
// DefaultPermissionOverrides contains permission overrides for users where
// no permissions have been set by the server. If a permission is not set in
// this map, it's assumed the user has that permission.
DefaultPermissionOverrides = map[Permission]bool{
DefaultPermissionOverrides = map[Permission]bool{ // +checklocksignore: Global readonly variable.
PERMISSION_HIDE_DISPLAYNAMES: false,
}
)

View file

@ -67,7 +67,9 @@ func (w *SingleWaiter) cancel() {
type SingleNotifier struct {
sync.Mutex
waiter *SingleWaiter
// +checklocks:Mutex
waiter *SingleWaiter
// +checklocks:Mutex
waiters map[*SingleWaiter]bool
}

View file

@ -202,7 +202,8 @@ type TestClient struct {
hub *Hub
server *httptest.Server
mu sync.Mutex
mu sync.Mutex
// +checklocks:mu
conn *websocket.Conn
localAddr net.Addr

View file

@ -94,7 +94,8 @@ type memoryThrottler struct {
getNow func() time.Time
doDelay func(context.Context, time.Duration)
mu sync.RWMutex
mu sync.RWMutex
// +checklocks:mu
clients map[string]map[string][]throttleEntry
closer *Closer

View file

@ -35,11 +35,15 @@ type TransientListener interface {
}
type TransientData struct {
mu sync.Mutex
data api.StringMap
mu sync.Mutex
// +checklocks:mu
data api.StringMap
// +checklocks:mu
listeners map[TransientListener]bool
timers map[string]*time.Timer
ttlCh chan<- struct{}
// +checklocks:mu
timers map[string]*time.Timer
// +checklocks:mu
ttlCh chan<- struct{}
}
// NewTransientData creates a new transient data container.
@ -47,6 +51,7 @@ func NewTransientData() *TransientData {
return &TransientData{}
}
// +checklocks:t.mu
func (t *TransientData) sendMessageToListener(listener TransientListener, message *ServerMessage) {
t.mu.Unlock()
defer t.mu.Lock()
@ -54,6 +59,7 @@ func (t *TransientData) sendMessageToListener(listener TransientListener, messag
listener.SendMessage(message)
}
// +checklocks:t.mu
func (t *TransientData) notifySet(key string, prev, value any) {
msg := &ServerMessage{
Type: "transient",
@ -69,6 +75,7 @@ func (t *TransientData) notifySet(key string, prev, value any) {
}
}
// +checklocks:t.mu
func (t *TransientData) notifyDeleted(key string, prev any) {
msg := &ServerMessage{
Type: "transient",
@ -112,6 +119,7 @@ func (t *TransientData) RemoveListener(listener TransientListener) {
delete(t.listeners, listener)
}
// +checklocks:t.mu
func (t *TransientData) updateTTL(key string, value any, ttl time.Duration) {
if ttl <= 0 {
delete(t.timers, key)
@ -120,6 +128,7 @@ func (t *TransientData) updateTTL(key string, value any, ttl time.Duration) {
}
}
// +checklocks:t.mu
func (t *TransientData) removeAfterTTL(key string, value any, ttl time.Duration) {
if ttl <= 0 {
return
@ -147,6 +156,7 @@ func (t *TransientData) removeAfterTTL(key string, value any, ttl time.Duration)
t.timers[key] = timer
}
// +checklocks:t.mu
func (t *TransientData) doSet(key string, value any, prev any, ttl time.Duration) {
if t.data == nil {
t.data = make(api.StringMap)
@ -210,6 +220,7 @@ func (t *TransientData) CompareAndSetTTL(key string, old, value any, ttl time.Du
return true
}
// +checklocks:t.mu
func (t *TransientData) doRemove(key string, prev any) {
delete(t.data, key)
if old, found := t.timers[key]; found {
@ -243,6 +254,7 @@ func (t *TransientData) CompareAndRemove(key string, old any) bool {
return t.compareAndRemove(key, old)
}
// +checklocks:t.mu
func (t *TransientData) compareAndRemove(key string, old any) bool {
prev, found := t.data[key]
if !found || !reflect.DeepEqual(prev, old) {

View file

@ -92,6 +92,7 @@ type MockTransientListener struct {
sending chan struct{}
done chan struct{}
// +checklocks:mu
data *TransientData
}

View file

@ -45,7 +45,7 @@ func TestVirtualSession(t *testing.T) {
backend := &Backend{
id: "compat",
}
room, err := hub.createRoom(roomId, emptyProperties, backend)
room, err := hub.CreateRoom(roomId, emptyProperties, backend)
require.NoError(err)
defer room.Close()
@ -229,7 +229,7 @@ func TestVirtualSessionActorInformation(t *testing.T) {
backend := &Backend{
id: "compat",
}
room, err := hub.createRoom(roomId, emptyProperties, backend)
room, err := hub.CreateRoom(roomId, emptyProperties, backend)
require.NoError(err)
defer room.Close()
@ -439,7 +439,7 @@ func TestVirtualSessionCustomInCall(t *testing.T) {
backend := &Backend{
id: "compat",
}
room, err := hub.createRoom(roomId, emptyProperties, backend)
room, err := hub.CreateRoom(roomId, emptyProperties, backend)
require.NoError(err)
defer room.Close()
@ -581,7 +581,7 @@ func TestVirtualSessionCleanup(t *testing.T) {
backend := &Backend{
id: "compat",
}
room, err := hub.createRoom(roomId, emptyProperties, backend)
room, err := hub.CreateRoom(roomId, emptyProperties, backend)
require.NoError(err)
defer room.Close()