From c134883138cb89ee5b552bdf24ae477e0782925f Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 15 Jun 2023 13:36:53 +0200 Subject: [PATCH] Switch to atomic types from Go 1.19 --- capabilities_test.go | 12 ++--- client.go | 17 +++---- client/main.go | 22 ++++----- clientsession.go | 24 +++++---- closer.go | 6 +-- grpc_client.go | 16 +++--- hub.go | 43 ++++++++-------- hub_test.go | 7 ++- janus_client.go | 9 ++-- mcu_janus.go | 8 ++- mcu_proxy.go | 108 ++++++++++++++++++++--------------------- mcu_test.go | 6 +-- natsclient_test.go | 7 ++- proxy/proxy_client.go | 7 ++- proxy/proxy_server.go | 42 ++++++++-------- proxy/proxy_session.go | 25 +++++----- virtualsession.go | 34 ++++++------- 17 files changed, 183 insertions(+), 210 deletions(-) diff --git a/capabilities_test.go b/capabilities_test.go index 22f653b..cb94994 100644 --- a/capabilities_test.go +++ b/capabilities_test.go @@ -192,9 +192,9 @@ func TestCapabilities(t *testing.T) { } func TestInvalidateCapabilities(t *testing.T) { - var called uint32 + var called atomic.Uint32 url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) { - atomic.AddUint32(&called, 1) + called.Add(1) }) ctx, cancel := context.WithTimeout(context.Background(), testTimeout) @@ -209,7 +209,7 @@ func TestInvalidateCapabilities(t *testing.T) { t.Errorf("expected direct response") } - if value := atomic.LoadUint32(&called); value != 1 { + if value := called.Load(); value != 1 { t.Errorf("expected called %d, got %d", 1, value) } @@ -224,7 +224,7 @@ func TestInvalidateCapabilities(t *testing.T) { t.Errorf("expected direct response") } - if value := atomic.LoadUint32(&called); value != 2 { + if value := called.Load(); value != 2 { t.Errorf("expected called %d, got %d", 2, value) } @@ -239,7 +239,7 @@ func TestInvalidateCapabilities(t *testing.T) { t.Errorf("expected cached response") } - if value := atomic.LoadUint32(&called); value != 2 { + if value := called.Load(); value != 2 { t.Errorf("expected called %d, got %d", 2, value) } @@ -258,7 +258,7 @@ func TestInvalidateCapabilities(t *testing.T) { t.Errorf("expected direct response") } - if value := atomic.LoadUint32(&called); value != 3 { + if value := called.Load(); value != 3 { t.Errorf("expected called %d, got %d", 3, value) } } diff --git a/client.go b/client.go index 7c37232..0fe42e9 100644 --- a/client.go +++ b/client.go @@ -30,7 +30,6 @@ import ( "sync" "sync/atomic" "time" - "unsafe" "github.com/gorilla/websocket" "github.com/mailru/easyjson" @@ -108,11 +107,11 @@ type Client struct { addr string handler ClientHandler agent string - closed uint32 + closed atomic.Int32 country *string logRTT bool - session unsafe.Pointer + session atomic.Pointer[ClientSession] mu sync.Mutex @@ -150,7 +149,7 @@ func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler Cli } func (c *Client) IsConnected() bool { - return atomic.LoadUint32(&c.closed) == 0 + return c.closed.Load() == 0 } func (c *Client) IsAuthenticated() bool { @@ -158,11 +157,11 @@ func (c *Client) IsAuthenticated() bool { } func (c *Client) GetSession() *ClientSession { - return (*ClientSession)(atomic.LoadPointer(&c.session)) + return c.session.Load() } func (c *Client) SetSession(session *ClientSession) { - atomic.StorePointer(&c.session, unsafe.Pointer(session)) + c.session.Store(session) } func (c *Client) RemoteAddr() string { @@ -188,7 +187,7 @@ func (c *Client) Country() string { } func (c *Client) Close() { - if atomic.LoadUint32(&c.closed) >= 2 { + if c.closed.Load() >= 2 { // Prevent reentrant call in case this was the second closing // step. Would otherwise deadlock in the "Once.Do" call path // through "Hub.processUnregister" (which calls "Close" again). @@ -201,7 +200,7 @@ func (c *Client) Close() { } func (c *Client) doClose() { - closed := atomic.AddUint32(&c.closed, 1) + closed := c.closed.Add(1) if closed == 1 { c.mu.Lock() defer c.mu.Unlock() @@ -329,7 +328,7 @@ func (c *Client) ReadPump() { } // Stop processing if the client was closed. - if atomic.LoadUint32(&c.closed) != 0 { + if !c.IsConnected() { bufferPool.Put(decodeBuffer) break } diff --git a/client/main.go b/client/main.go index e3d3a35..cd98f59 100644 --- a/client/main.go +++ b/client/main.go @@ -81,8 +81,8 @@ const ( ) type Stats struct { - numRecvMessages uint64 - numSentMessages uint64 + numRecvMessages atomic.Uint64 + numSentMessages atomic.Uint64 resetRecvMessages uint64 resetSentMessages uint64 @@ -90,8 +90,8 @@ type Stats struct { } func (s *Stats) reset(start time.Time) { - s.resetRecvMessages = atomic.AddUint64(&s.numRecvMessages, 0) - s.resetSentMessages = atomic.AddUint64(&s.numSentMessages, 0) + s.resetRecvMessages = s.numRecvMessages.Load() + s.resetSentMessages = s.numSentMessages.Load() s.start = start } @@ -103,9 +103,9 @@ func (s *Stats) Log() { return } - totalSentMessages := atomic.AddUint64(&s.numSentMessages, 0) + totalSentMessages := s.numSentMessages.Load() sentMessages := totalSentMessages - s.resetSentMessages - totalRecvMessages := atomic.AddUint64(&s.numRecvMessages, 0) + totalRecvMessages := s.numRecvMessages.Load() recvMessages := totalRecvMessages - s.resetRecvMessages log.Printf("Stats: sent=%d (%d/sec), recv=%d (%d/sec), delta=%d", totalSentMessages, sentMessages/perSec, @@ -125,7 +125,7 @@ type SignalingClient struct { conn *websocket.Conn stats *Stats - closed uint32 + closed atomic.Bool stopChan chan struct{} @@ -164,7 +164,7 @@ func NewSignalingClient(cookie *securecookie.SecureCookie, url string, stats *St } func (c *SignalingClient) Close() { - if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + if !c.closed.CompareAndSwap(false, true) { return } @@ -197,7 +197,7 @@ func (c *SignalingClient) Send(message *signaling.ClientMessage) { } func (c *SignalingClient) processMessage(message *signaling.ServerMessage) { - atomic.AddUint64(&c.stats.numRecvMessages, 1) + c.stats.numRecvMessages.Add(1) switch message.Type { case "hello": c.processHelloMessage(message) @@ -334,7 +334,7 @@ func (c *SignalingClient) writeInternal(message *signaling.ClientMessage) bool { } writer.Close() - atomic.AddUint64(&c.stats.numSentMessages, 1) + c.stats.numSentMessages.Add(1) return true close: @@ -383,7 +383,7 @@ func (c *SignalingClient) SendMessages(clients []*SignalingClient) { sessionIds[c] = c.PublicSessionId() } - for atomic.LoadUint32(&c.closed) == 0 { + for !c.closed.Load() { now := time.Now() sender := c diff --git a/clientsession.go b/clientsession.go index 15ff832..d3174da 100644 --- a/clientsession.go +++ b/clientsession.go @@ -31,7 +31,6 @@ import ( "sync" "sync/atomic" "time" - "unsafe" "github.com/pion/sdp/v3" ) @@ -50,9 +49,6 @@ var ( type ResponseHandlerFunc func(message *ClientMessage) bool type ClientSession struct { - roomJoinTime int64 - inCall uint32 - hub *Hub events AsyncEvents privateId string @@ -64,6 +60,7 @@ type ClientSession struct { userId string userData *json.RawMessage + inCall atomic.Uint32 supportsPermissions bool permissions map[Permission]bool @@ -76,7 +73,8 @@ type ClientSession struct { mu sync.Mutex client *Client - room unsafe.Pointer + room atomic.Pointer[Room] + roomJoinTime atomic.Int64 roomSessionId string publisherWaiters ChannelWaiters @@ -171,7 +169,7 @@ func (s *ClientSession) ClientType() string { // GetInCall is only used for internal clients. func (s *ClientSession) GetInCall() int { - return int(atomic.LoadUint32(&s.inCall)) + return int(s.inCall.Load()) } func (s *ClientSession) SetInCall(inCall int) bool { @@ -180,12 +178,12 @@ func (s *ClientSession) SetInCall(inCall int) bool { } for { - old := atomic.LoadUint32(&s.inCall) + old := s.inCall.Load() if old == uint32(inCall) { return false } - if atomic.CompareAndSwapUint32(&s.inCall, old, uint32(inCall)) { + if s.inCall.CompareAndSwap(old, uint32(inCall)) { return true } } @@ -340,11 +338,11 @@ func (s *ClientSession) IsExpired(now time.Time) bool { } func (s *ClientSession) SetRoom(room *Room) { - atomic.StorePointer(&s.room, unsafe.Pointer(room)) + s.room.Store(room) if room != nil { - atomic.StoreInt64(&s.roomJoinTime, time.Now().UnixNano()) + s.roomJoinTime.Store(time.Now().UnixNano()) } else { - atomic.StoreInt64(&s.roomJoinTime, 0) + s.roomJoinTime.Store(0) } s.seenJoinedLock.Lock() @@ -353,11 +351,11 @@ func (s *ClientSession) SetRoom(room *Room) { } func (s *ClientSession) GetRoom() *Room { - return (*Room)(atomic.LoadPointer(&s.room)) + return s.room.Load() } func (s *ClientSession) getRoomJoinTime() time.Time { - t := atomic.LoadInt64(&s.roomJoinTime) + t := s.roomJoinTime.Load() if t == 0 { return time.Time{} } diff --git a/closer.go b/closer.go index a68c850..ea00769 100644 --- a/closer.go +++ b/closer.go @@ -26,7 +26,7 @@ import ( ) type Closer struct { - closed uint32 + closed atomic.Bool C chan struct{} } @@ -37,11 +37,11 @@ func NewCloser() *Closer { } func (c *Closer) IsClosed() bool { - return atomic.LoadUint32(&c.closed) != 0 + return c.closed.Load() } func (c *Closer) Close() { - if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + if c.closed.CompareAndSwap(false, true) { close(c.C) } } diff --git a/grpc_client.go b/grpc_client.go index 82b649d..1f74d62 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -51,7 +51,7 @@ const ( var ( lookupGrpcIp = net.LookupIP // can be overwritten from tests - customResolverPrefix uint64 + customResolverPrefix atomic.Uint64 ) func init() { @@ -75,12 +75,12 @@ func newGrpcClientImpl(conn grpc.ClientConnInterface) *grpcClientImpl { } type GrpcClient struct { - isSelf uint32 - ip net.IP target string conn *grpc.ClientConn impl *grpcClientImpl + + isSelf atomic.Bool } type customIpResolver struct { @@ -125,7 +125,7 @@ func NewGrpcClient(target string, ip net.IP, opts ...grpc.DialOption) (*GrpcClie var conn *grpc.ClientConn var err error if ip != nil { - prefix := atomic.AddUint64(&customResolverPrefix, 1) + prefix := customResolverPrefix.Add(1) addr := ip.String() hostname := target if host, port, err := net.SplitHostPort(target); err == nil { @@ -168,15 +168,11 @@ func (c *GrpcClient) Close() error { } func (c *GrpcClient) IsSelf() bool { - return atomic.LoadUint32(&c.isSelf) != 0 + return c.isSelf.Load() } func (c *GrpcClient) SetSelf(self bool) { - if self { - atomic.StoreUint32(&c.isSelf, 1) - } else { - atomic.StoreUint32(&c.isSelf, 0) - } + c.isSelf.Store(self) } func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) { diff --git a/hub.go b/hub.go index bbcfc98..f57687b 100644 --- a/hub.go +++ b/hub.go @@ -112,9 +112,6 @@ func init() { } type Hub struct { - // 64-bit members that are accessed atomically must be 64-bit aligned. - sid uint64 - events AsyncEvents upgrader websocket.Upgrader cookie *securecookie.SecureCookie @@ -123,8 +120,8 @@ type Hub struct { welcome atomic.Value // *ServerMessage closer *Closer - readPumpActive uint32 - writePumpActive uint32 + readPumpActive atomic.Int32 + writePumpActive atomic.Int32 roomUpdated chan *BackendServerRoomRequest roomDeleted chan *BackendServerRoomRequest @@ -134,6 +131,7 @@ type Hub struct { mu sync.RWMutex ru sync.RWMutex + sid atomic.Uint64 clients map[uint64]*Client sessions map[uint64]Session rooms map[string]*Room @@ -160,7 +158,7 @@ type Hub struct { geoip *GeoLookup geoipOverrides map[*net.IPNet]string - geoipUpdating int32 + geoipUpdating atomic.Bool rpcServer *GrpcServer rpcClients *GrpcClients @@ -414,12 +412,12 @@ func (h *Hub) updateGeoDatabase() { return } - if !atomic.CompareAndSwapInt32(&h.geoipUpdating, 0, 1) { + if !h.geoipUpdating.CompareAndSwap(false, true) { // Already updating return } - defer atomic.CompareAndSwapInt32(&h.geoipUpdating, 1, 0) + defer h.geoipUpdating.Store(false) delay := time.Second for !h.closer.IsClosed() { err := h.geoip.Update() @@ -699,9 +697,9 @@ func (h *Hub) sendWelcome(client *Client) { } func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData { - sid := atomic.AddUint64(&h.sid, 1) + sid := h.sid.Add(1) for sid == 0 { - sid = atomic.AddUint64(&h.sid, 1) + sid = h.sid.Add(1) } sessionIdData := &SessionIdData{ Sid: sid, @@ -725,10 +723,6 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B return } - sid := atomic.AddUint64(&h.sid, 1) - for sid == 0 { - sid = atomic.AddUint64(&h.sid, 1) - } sessionIdData := h.newSessionIdData(backend) privateSessionId, err := h.encodeSessionId(sessionIdData, privateSessionName) if err != nil { @@ -764,7 +758,8 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B } if limit := uint32(backend.Limit()); limit > 0 && h.rpcClients != nil { - totalCount := uint32(backend.Len()) + var totalCount atomic.Uint32 + totalCount.Add(uint32(backend.Len())) var wg sync.WaitGroup ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -781,12 +776,12 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B if count > 0 { log.Printf("%d sessions connected for %s on %s", count, backend.Url(), c.Target()) - atomic.AddUint32(&totalCount, count) + totalCount.Add(count) } }(client) } wg.Wait() - if totalCount > limit { + if totalCount.Load() > limit { backend.RemoveSession(session) log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), SessionLimitExceeded) session.Close() @@ -2054,7 +2049,7 @@ func (h *Hub) isInSameCallRemote(ctx context.Context, senderSession *ClientSessi return false } - var result int32 + var result atomic.Bool var wg sync.WaitGroup rpcCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -2074,12 +2069,12 @@ func (h *Hub) isInSameCallRemote(ctx context.Context, senderSession *ClientSessi } cancel() - atomic.StoreInt32(&result, 1) + result.Store(true) }(client) } wg.Wait() - return atomic.LoadInt32(&result) != 0 + return result.Load() } func (h *Hub) isInSameCall(ctx context.Context, senderSession *ClientSession, recipientSessionId string) bool { @@ -2364,13 +2359,13 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { h.processNewClient(client) go func(h *Hub) { - atomic.AddUint32(&h.writePumpActive, 1) - defer atomic.AddUint32(&h.writePumpActive, ^uint32(0)) + h.writePumpActive.Add(1) + defer h.writePumpActive.Add(-1) client.WritePump() }(h) go func(h *Hub) { - atomic.AddUint32(&h.readPumpActive, 1) - defer atomic.AddUint32(&h.readPumpActive, ^uint32(0)) + h.readPumpActive.Add(1) + defer h.readPumpActive.Add(-1) client.ReadPump() }(h) } diff --git a/hub_test.go b/hub_test.go index 90b5f93..0958d76 100644 --- a/hub_test.go +++ b/hub_test.go @@ -42,7 +42,6 @@ import ( "reflect" "strings" "sync" - "sync/atomic" "testing" "time" @@ -279,8 +278,8 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) { h.ru.Lock() rooms := len(h.rooms) h.ru.Unlock() - readActive := atomic.LoadUint32(&h.readPumpActive) - writeActive := atomic.LoadUint32(&h.writePumpActive) + readActive := h.readPumpActive.Load() + writeActive := h.writePumpActive.Load() if clients == 0 && rooms == 0 && sessions == 0 && readActive == 0 && writeActive == 0 { break } @@ -1631,7 +1630,7 @@ func TestClientHelloResumeOtherHub(t *testing.T) { } // Simulate a restart of the hub. - atomic.StoreUint64(&hub.sid, 0) + hub.sid.Store(0) sessions := make([]Session, 0) hub.mu.Lock() for _, session := range hub.sessions { diff --git a/janus_client.go b/janus_client.go index 5dc1991..7c21b8d 100644 --- a/janus_client.go +++ b/janus_client.go @@ -221,8 +221,6 @@ func (l *dummyGatewayListener) ConnectionInterrupted() { // Gateway represents a connection to an instance of the Janus Gateway. type JanusGateway struct { - nextTransaction uint64 - listener GatewayListener // Sessions is a map of the currently active sessions to the gateway. @@ -232,8 +230,9 @@ type JanusGateway struct { // and Gateway.Unlock() methods provided by the embedded sync.Mutex. sync.Mutex - conn *websocket.Conn - transactions map[uint64]*transaction + conn *websocket.Conn + nextTransaction atomic.Uint64 + transactions map[uint64]*transaction closer *Closer @@ -328,7 +327,7 @@ func (gateway *JanusGateway) removeTransaction(id uint64) { } func (gateway *JanusGateway) send(msg map[string]interface{}, t *transaction) (uint64, error) { - id := atomic.AddUint64(&gateway.nextTransaction, 1) + id := gateway.nextTransaction.Add(1) msg["transaction"] = strconv.FormatUint(id, 10) data, err := json.Marshal(msg) if err != nil { diff --git a/mcu_janus.go b/mcu_janus.go index c5ca17f..e7a1e39 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -132,9 +132,6 @@ type clientInterface interface { } type mcuJanus struct { - // 64-bit members that are accessed atomically must be 64-bit aligned. - clientId uint64 - url string mu sync.Mutex @@ -150,6 +147,7 @@ type mcuJanus struct { muClients sync.Mutex clients map[clientInterface]bool + clientId atomic.Uint64 publishers map[string]*mcuJanusPublisher publisherCreated Notifier @@ -799,7 +797,7 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id st mcu: m, listener: listener, - id: atomic.AddUint64(&m.clientId, 1), + id: m.clientId.Add(1), session: session, roomId: roomId, sid: sid, @@ -1040,7 +1038,7 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ mcu: m, listener: listener, - id: atomic.AddUint64(&m.clientId, 1), + id: m.clientId.Add(1), roomId: pub.roomId, sid: strconv.FormatUint(handle.Id, 10), streamType: streamType, diff --git a/mcu_proxy.go b/mcu_proxy.go index a0d03fa..d675bd9 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -294,31 +294,29 @@ func (s *mcuProxySubscriber) ProcessEvent(msg *EventProxyServerMessage) { } type mcuProxyConnection struct { - // 64-bit members that are accessed atomically must be 64-bit aligned. - reconnectInterval int64 - msgId int64 - load int64 - proxy *mcuProxy rawUrl string url *url.URL ip net.IP + load atomic.Int64 mu sync.Mutex closer *Closer closedDone *Closer - closed uint32 + closed atomic.Bool conn *websocket.Conn connectedSince time.Time reconnectTimer *time.Timer - shutdownScheduled uint32 - closeScheduled uint32 - trackClose uint32 - temporary uint32 + reconnectInterval atomic.Int64 + shutdownScheduled atomic.Bool + closeScheduled atomic.Bool + trackClose atomic.Bool + temporary atomic.Bool connectedNotifier SingleNotifier + msgId atomic.Int64 helloMsgId string sessionId string country atomic.Value @@ -340,19 +338,19 @@ func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP) (*mcuProx } conn := &mcuProxyConnection{ - proxy: proxy, - rawUrl: baseUrl, - url: parsed, - ip: ip, - closer: NewCloser(), - closedDone: NewCloser(), - reconnectInterval: int64(initialReconnectInterval), - load: loadNotConnected, - callbacks: make(map[string]func(*ProxyServerMessage)), - publishers: make(map[string]*mcuProxyPublisher), - publisherIds: make(map[string]string), - subscribers: make(map[string]*mcuProxySubscriber), + proxy: proxy, + rawUrl: baseUrl, + url: parsed, + ip: ip, + closer: NewCloser(), + closedDone: NewCloser(), + callbacks: make(map[string]func(*ProxyServerMessage)), + publishers: make(map[string]*mcuProxyPublisher), + publisherIds: make(map[string]string), + subscribers: make(map[string]*mcuProxySubscriber), } + conn.reconnectInterval.Store(int64(initialReconnectInterval)) + conn.load.Store(loadNotConnected) conn.country.Store("") return conn, nil } @@ -405,7 +403,7 @@ func (c *mcuProxyConnection) GetStats() *mcuProxyConnectionStats { } func (c *mcuProxyConnection) Load() int64 { - return atomic.LoadInt64(&c.load) + return c.load.Load() } func (c *mcuProxyConnection) Country() string { @@ -413,31 +411,31 @@ func (c *mcuProxyConnection) Country() string { } func (c *mcuProxyConnection) IsTemporary() bool { - return atomic.LoadUint32(&c.temporary) != 0 + return c.temporary.Load() } func (c *mcuProxyConnection) setTemporary() { - atomic.StoreUint32(&c.temporary, 1) + c.temporary.Store(true) } func (c *mcuProxyConnection) clearTemporary() { - atomic.StoreUint32(&c.temporary, 0) + c.temporary.Store(false) } func (c *mcuProxyConnection) IsShutdownScheduled() bool { - return atomic.LoadUint32(&c.shutdownScheduled) != 0 || atomic.LoadUint32(&c.closeScheduled) != 0 + return c.shutdownScheduled.Load() || c.closeScheduled.Load() } func (c *mcuProxyConnection) readPump() { defer func() { - if atomic.LoadUint32(&c.closed) == 0 { + if !c.closed.Load() { c.scheduleReconnect() } else { c.closedDone.Close() } }() defer c.close() - defer atomic.StoreInt64(&c.load, loadNotConnected) + defer c.load.Store(loadNotConnected) c.mu.Lock() conn := c.conn @@ -539,7 +537,7 @@ func (c *mcuProxyConnection) sendClose() error { } func (c *mcuProxyConnection) stop(ctx context.Context) { - if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + if !c.closed.CompareAndSwap(false, true) { return } @@ -571,18 +569,18 @@ func (c *mcuProxyConnection) close() { if c.conn != nil { c.conn.Close() c.conn = nil - if atomic.CompareAndSwapUint32(&c.trackClose, 1, 0) { + if c.trackClose.CompareAndSwap(true, false) { statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Dec() } } } func (c *mcuProxyConnection) stopCloseIfEmpty() { - atomic.StoreUint32(&c.closeScheduled, 0) + c.closeScheduled.Store(false) } func (c *mcuProxyConnection) closeIfEmpty() bool { - atomic.StoreUint32(&c.closeScheduled, 1) + c.closeScheduled.Store(true) var total int64 c.publishersLock.RLock() @@ -620,14 +618,14 @@ func (c *mcuProxyConnection) scheduleReconnect() { return } - interval := atomic.LoadInt64(&c.reconnectInterval) + interval := c.reconnectInterval.Load() c.reconnectTimer.Reset(time.Duration(interval)) interval = interval * 2 if interval > int64(maxReconnectInterval) { interval = int64(maxReconnectInterval) } - atomic.StoreInt64(&c.reconnectInterval, interval) + c.reconnectInterval.Store(interval) } func (c *mcuProxyConnection) reconnect() { @@ -673,15 +671,15 @@ func (c *mcuProxyConnection) reconnect() { } log.Printf("Connected to %s", c) - atomic.StoreUint32(&c.closed, 0) + c.closed.Store(false) c.mu.Lock() c.connectedSince = time.Now() c.conn = conn c.mu.Unlock() - atomic.StoreInt64(&c.reconnectInterval, int64(initialReconnectInterval)) - atomic.StoreUint32(&c.shutdownScheduled, 0) + c.reconnectInterval.Store(int64(initialReconnectInterval)) + c.shutdownScheduled.Store(false) if err := c.sendHello(); err != nil { log.Printf("Could not send hello request to %s: %s", c, err) c.scheduleReconnect() @@ -723,7 +721,7 @@ func (c *mcuProxyConnection) removePublisher(publisher *mcuProxyPublisher) { } delete(c.publisherIds, publisher.id+"|"+publisher.StreamType()) - if len(c.publishers) == 0 && (atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary()) { + if len(c.publishers) == 0 && (c.closeScheduled.Load() || c.IsTemporary()) { go c.closeIfEmpty() } } @@ -740,7 +738,7 @@ func (c *mcuProxyConnection) clearPublishers() { c.publishers = make(map[string]*mcuProxyPublisher) c.publisherIds = make(map[string]string) - if atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary() { + if c.closeScheduled.Load() || c.IsTemporary() { go c.closeIfEmpty() } } @@ -754,7 +752,7 @@ func (c *mcuProxyConnection) removeSubscriber(subscriber *mcuProxySubscriber) { statsSubscribersCurrent.WithLabelValues(subscriber.StreamType()).Dec() } - if len(c.subscribers) == 0 && (atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary()) { + if len(c.subscribers) == 0 && (c.closeScheduled.Load() || c.IsTemporary()) { go c.closeIfEmpty() } } @@ -770,7 +768,7 @@ func (c *mcuProxyConnection) clearSubscribers() { }(c.subscribers) c.subscribers = make(map[string]*mcuProxySubscriber) - if atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary() { + if c.closeScheduled.Load() || c.IsTemporary() { go c.closeIfEmpty() } } @@ -831,7 +829,7 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) { } else { log.Printf("Received session %s from %s", c.sessionId, c) } - if atomic.CompareAndSwapUint32(&c.trackClose, 0, 1) { + if c.trackClose.CompareAndSwap(false, true) { statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Inc() } @@ -902,12 +900,12 @@ func (c *mcuProxyConnection) processEvent(msg *ProxyServerMessage) { if proxyDebugMessages { log.Printf("Load of %s now at %d", c, event.Load) } - atomic.StoreInt64(&c.load, event.Load) + c.load.Store(event.Load) statsProxyBackendLoadCurrent.WithLabelValues(c.url.String()).Set(float64(event.Load)) return case "shutdown-scheduled": log.Printf("Proxy %s is scheduled to shutdown", c) - atomic.StoreUint32(&c.shutdownScheduled, 1) + c.shutdownScheduled.Store(true) return } @@ -945,7 +943,7 @@ func (c *mcuProxyConnection) processBye(msg *ProxyServerMessage) { } func (c *mcuProxyConnection) sendHello() error { - c.helloMsgId = strconv.FormatInt(atomic.AddInt64(&c.msgId, 1), 10) + c.helloMsgId = strconv.FormatInt(c.msgId.Add(1), 10) msg := &ProxyClientMessage{ Id: c.helloMsgId, Type: "hello", @@ -992,7 +990,7 @@ func (c *mcuProxyConnection) sendMessageLocked(msg *ProxyClientMessage) error { } func (c *mcuProxyConnection) performAsyncRequest(ctx context.Context, msg *ProxyClientMessage, callback func(err error, response *ProxyServerMessage)) { - msgId := strconv.FormatInt(atomic.AddInt64(&c.msgId, 1), 10) + msgId := strconv.FormatInt(c.msgId.Add(1), 10) msg.Id = msgId c.mu.Lock() @@ -1094,10 +1092,6 @@ func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuList } type mcuProxy struct { - // 64-bit members that are accessed atomically must be 64-bit aligned. - connRequests int64 - nextSort int64 - urlType string tokenId string tokenKey *rsa.PrivateKey @@ -1113,6 +1107,8 @@ type mcuProxy struct { connectionsMap map[string][]*mcuProxyConnection connectionsMu sync.RWMutex proxyTimeout time.Duration + connRequests atomic.Int64 + nextSort atomic.Int64 dnsDiscovery bool stopping chan struct{} @@ -1510,7 +1506,7 @@ func (m *mcuProxy) configureStatic(config *goconf.ConfigFile, fromReload bool) e } if changed { - atomic.StoreInt64(&m.nextSort, 0) + m.nextSort.Store(0) } } else { for u, conns := range created { @@ -1644,7 +1640,7 @@ func (m *mcuProxy) EtcdKeyUpdated(client *EtcdClient, key string, data []byte) { m.urlToKey[info.Address] = key m.connections = append(m.connections, conn) m.connectionsMap[info.Address] = []*mcuProxyConnection{conn} - atomic.StoreInt64(&m.nextSort, 0) + m.nextSort.Store(0) } } @@ -1696,7 +1692,7 @@ func (m *mcuProxy) removeConnection(c *mcuProxyConnection) { m.connectionsMap[c.rawUrl] = conns } - atomic.StoreInt64(&m.nextSort, 0) + m.nextSort.Store(0) } } @@ -1829,8 +1825,8 @@ func (m *mcuProxy) getSortedConnections(initiator McuInitiator) []*mcuProxyConne // Connections are re-sorted every requests or // every . now := time.Now().UnixNano() - if atomic.AddInt64(&m.connRequests, 1)%connectionSortRequests == 0 || atomic.LoadInt64(&m.nextSort) <= now { - atomic.StoreInt64(&m.nextSort, now+int64(connectionSortInterval)) + if m.connRequests.Add(1)%connectionSortRequests == 0 || m.nextSort.Load() <= now { + m.nextSort.Store(now + int64(connectionSortInterval)) sorted := make(mcuProxyConnectionsList, len(connections)) copy(sorted, connections) diff --git a/mcu_test.go b/mcu_test.go index a95c880..1cc88b1 100644 --- a/mcu_test.go +++ b/mcu_test.go @@ -139,7 +139,7 @@ func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publi } type TestMCUClient struct { - closed int32 + closed atomic.Bool id string sid string @@ -159,13 +159,13 @@ func (c *TestMCUClient) StreamType() string { } func (c *TestMCUClient) Close(ctx context.Context) { - if atomic.CompareAndSwapInt32(&c.closed, 0, 1) { + if c.closed.CompareAndSwap(false, true) { log.Printf("Close MCU client %s", c.id) } } func (c *TestMCUClient) isClosed() bool { - return atomic.LoadInt32(&c.closed) != 0 + return c.closed.Load() } type TestMCUPublisher struct { diff --git a/natsclient_test.go b/natsclient_test.go index cc6cbfa..9bef7d9 100644 --- a/natsclient_test.go +++ b/natsclient_test.go @@ -63,7 +63,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { } ch := make(chan struct{}) - received := int32(0) + var received atomic.Int32 max := int32(20) ready := make(chan struct{}) quit := make(chan struct{}) @@ -73,7 +73,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { for { select { case <-dest: - total := atomic.AddInt32(&received, 1) + total := received.Add(1) if total == max { err := sub.Unsubscribe() if err != nil { @@ -98,8 +98,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) { } <-ch - r := atomic.LoadInt32(&received) - if r != max { + if r := received.Load(); r != max { t.Fatalf("Received wrong # of messages: %d vs %d", r, max) } } diff --git a/proxy/proxy_client.go b/proxy/proxy_client.go index c9c495a..dde4de8 100644 --- a/proxy/proxy_client.go +++ b/proxy/proxy_client.go @@ -24,7 +24,6 @@ package main import ( "sync/atomic" "time" - "unsafe" "github.com/gorilla/websocket" signaling "github.com/strukturag/nextcloud-spreed-signaling" @@ -35,7 +34,7 @@ type ProxyClient struct { proxy *ProxyServer - session unsafe.Pointer + session atomic.Pointer[ProxySession] } func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*ProxyClient, error) { @@ -47,11 +46,11 @@ func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*Pro } func (c *ProxyClient) GetSession() *ProxySession { - return (*ProxySession)(atomic.LoadPointer(&c.session)) + return c.session.Load() } func (c *ProxyClient) SetSession(session *ProxySession) { - atomic.StorePointer(&c.session, unsafe.Pointer(session)) + c.session.Store(session) } func (c *ProxyClient) OnClosed(client *signaling.Client) { diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 75c58b1..4cd1dc9 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -82,25 +82,23 @@ var ( ) type ProxyServer struct { - // 64-bit members that are accessed atomically must be 64-bit aligned. - load int64 - version string country string url string mcu signaling.Mcu - stopped uint32 + stopped atomic.Bool + load atomic.Int64 shutdownChannel chan struct{} - shutdownScheduled uint32 + shutdownScheduled atomic.Bool upgrader websocket.Upgrader tokens ProxyTokens statsAllowedIps *signaling.AllowedIps - sid uint64 + sid atomic.Uint64 cookie *securecookie.SecureCookie sessions map[uint64]*ProxySession sessionsLock sync.RWMutex @@ -279,12 +277,12 @@ loop: for { select { case <-updateLoadTicker.C: - if atomic.LoadUint32(&s.stopped) != 0 { + if s.stopped.Load() { break loop } s.updateLoad() case <-expireSessionsTicker.C: - if atomic.LoadUint32(&s.stopped) != 0 { + if s.stopped.Load() { break loop } s.expireSessions() @@ -296,12 +294,12 @@ func (s *ProxyServer) updateLoad() { // TODO: Take maximum bandwidth of clients into account when calculating // load (screensharing requires more than regular audio/video). load := s.GetClientCount() - if load == atomic.LoadInt64(&s.load) { + if load == s.load.Load() { return } - atomic.StoreInt64(&s.load, load) - if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + s.load.Store(load) + if s.shutdownScheduled.Load() { // Server is scheduled to shutdown, no need to update clients with current load. return } @@ -349,7 +347,7 @@ func (s *ProxyServer) expireSessions() { } func (s *ProxyServer) Stop() { - if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) { + if !s.stopped.CompareAndSwap(false, true) { return } @@ -364,7 +362,7 @@ func (s *ProxyServer) ShutdownChannel() <-chan struct{} { } func (s *ProxyServer) ScheduleShutdown() { - if !atomic.CompareAndSwapUint32(&s.shutdownScheduled, 0, 1) { + if !s.shutdownScheduled.CompareAndSwap(false, true) { return } @@ -449,7 +447,7 @@ func (s *ProxyServer) onMcuConnected() { } func (s *ProxyServer) onMcuDisconnected() { - if atomic.LoadUint32(&s.stopped) != 0 { + if s.stopped.Load() { // Shutting down, no need to notify. return } @@ -473,7 +471,7 @@ func (s *ProxyServer) sendCurrentLoad(session *ProxySession) { Type: "event", Event: &signaling.EventProxyServerMessage{ Type: "update-load", - Load: atomic.LoadInt64(&s.load), + Load: s.load.Load(), }, } session.sendMessage(msg) @@ -535,7 +533,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { log.Printf("Resumed session %s", session.PublicId()) session.MarkUsed() - if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + if s.shutdownScheduled.Load() { s.sendShutdownScheduled(session) } else { s.sendCurrentLoad(session) @@ -576,7 +574,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) { }, } client.SendMessage(response) - if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + if s.shutdownScheduled.Load() { s.sendShutdownScheduled(session) } else { s.sendCurrentLoad(session) @@ -610,7 +608,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s switch cmd.Type { case "create-publisher": - if atomic.LoadUint32(&s.shutdownScheduled) != 0 { + if s.shutdownScheduled.Load() { session.sendMessage(message.NewErrorServerMessage(ShutdownScheduled)) return } @@ -873,9 +871,9 @@ func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*Pro return nil, TokenExpired } - sid := atomic.AddUint64(&s.sid, 1) + sid := s.sid.Add(1) for sid == 0 { - sid = atomic.AddUint64(&s.sid, 1) + sid = s.sid.Add(1) } sessionIdData := &signaling.SessionIdData{ @@ -954,7 +952,7 @@ func (s *ProxyServer) DeleteClient(id string, client signaling.McuClient) bool { delete(s.clients, id) delete(s.clientIds, client.Id()) - if len(s.clients) == 0 && atomic.LoadUint32(&s.shutdownScheduled) != 0 { + if len(s.clients) == 0 && s.shutdownScheduled.Load() { go close(s.shutdownChannel) } return true @@ -981,7 +979,7 @@ func (s *ProxyServer) GetClientId(client signaling.McuClient) string { func (s *ProxyServer) getStats() map[string]interface{} { result := map[string]interface{}{ "sessions": s.GetSessionsCount(), - "load": atomic.LoadInt64(&s.load), + "load": s.load.Load(), "mcu": s.mcu.GetStats(), } return result diff --git a/proxy/proxy_session.go b/proxy/proxy_session.go index 4aec7fa..80445b2 100644 --- a/proxy/proxy_session.go +++ b/proxy/proxy_session.go @@ -37,12 +37,10 @@ const ( ) type ProxySession struct { - // 64-bit members that are accessed atomically must be 64-bit aligned. - lastUsed int64 - - proxy *ProxyServer - id string - sid uint64 + proxy *ProxyServer + id string + sid uint64 + lastUsed atomic.Int64 clientLock sync.Mutex client *ProxyClient @@ -58,11 +56,10 @@ type ProxySession struct { } func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession { - return &ProxySession{ - proxy: proxy, - id: id, - sid: sid, - lastUsed: time.Now().UnixNano(), + result := &ProxySession{ + proxy: proxy, + id: id, + sid: sid, publishers: make(map[string]signaling.McuPublisher), publisherIds: make(map[signaling.McuPublisher]string), @@ -70,6 +67,8 @@ func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession { subscribers: make(map[string]signaling.McuSubscriber), subscriberIds: make(map[signaling.McuSubscriber]string), } + result.MarkUsed() + return result } func (s *ProxySession) PublicId() string { @@ -81,7 +80,7 @@ func (s *ProxySession) Sid() uint64 { } func (s *ProxySession) LastUsed() time.Time { - lastUsed := atomic.LoadInt64(&s.lastUsed) + lastUsed := s.lastUsed.Load() return time.Unix(0, lastUsed) } @@ -92,7 +91,7 @@ func (s *ProxySession) IsExpired() bool { func (s *ProxySession) MarkUsed() { now := time.Now() - atomic.StoreInt64(&s.lastUsed, now.UnixNano()) + s.lastUsed.Store(now.UnixNano()) } func (s *ProxySession) Close() { diff --git a/virtualsession.go b/virtualsession.go index 1f16d59..8614199 100644 --- a/virtualsession.go +++ b/virtualsession.go @@ -28,7 +28,6 @@ import ( "net/url" "sync/atomic" "time" - "unsafe" ) const ( @@ -38,19 +37,18 @@ const ( ) type VirtualSession struct { - inCall uint32 - hub *Hub session *ClientSession privateId string publicId string data *SessionIdData - room unsafe.Pointer + room atomic.Pointer[Room] sessionId string userId string userData *json.RawMessage - flags uint32 + inCall atomic.Uint32 + flags atomic.Uint32 options *AddSessionOptions } @@ -69,9 +67,9 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string sessionId: msg.SessionId, userId: msg.UserId, userData: msg.User, - flags: msg.Flags, options: msg.Options, } + result.flags.Store(msg.Flags) if err := session.events.RegisterSessionListener(publicId, session.Backend(), result); err != nil { return nil, err @@ -99,7 +97,7 @@ func (s *VirtualSession) ClientType() string { } func (s *VirtualSession) GetInCall() int { - return int(atomic.LoadUint32(&s.inCall)) + return int(s.inCall.Load()) } func (s *VirtualSession) SetInCall(inCall int) bool { @@ -108,12 +106,12 @@ func (s *VirtualSession) SetInCall(inCall int) bool { } for { - old := atomic.LoadUint32(&s.inCall) + old := s.inCall.Load() if old == uint32(inCall) { return false } - if atomic.CompareAndSwapUint32(&s.inCall, old, uint32(inCall)) { + if s.inCall.CompareAndSwap(old, uint32(inCall)) { return true } } @@ -144,11 +142,11 @@ func (s *VirtualSession) UserData() *json.RawMessage { } func (s *VirtualSession) SetRoom(room *Room) { - atomic.StorePointer(&s.room, unsafe.Pointer(room)) + s.room.Store(room) } func (s *VirtualSession) GetRoom() *Room { - return (*Room)(atomic.LoadPointer(&s.room)) + return s.room.Load() } func (s *VirtualSession) LeaveRoom(notify bool) *Room { @@ -243,13 +241,13 @@ func (s *VirtualSession) SessionId() string { func (s *VirtualSession) AddFlags(flags uint32) bool { for { - old := atomic.LoadUint32(&s.flags) + old := s.flags.Load() if old&flags == flags { // Flags already set. return false } newFlags := old | flags - if atomic.CompareAndSwapUint32(&s.flags, old, newFlags) { + if s.flags.CompareAndSwap(old, newFlags) { return true } // Another thread updated the flags while we were checking, retry. @@ -258,13 +256,13 @@ func (s *VirtualSession) AddFlags(flags uint32) bool { func (s *VirtualSession) RemoveFlags(flags uint32) bool { for { - old := atomic.LoadUint32(&s.flags) + old := s.flags.Load() if old&flags == 0 { // Flags not set. return false } newFlags := old & ^flags - if atomic.CompareAndSwapUint32(&s.flags, old, newFlags) { + if s.flags.CompareAndSwap(old, newFlags) { return true } // Another thread updated the flags while we were checking, retry. @@ -273,19 +271,19 @@ func (s *VirtualSession) RemoveFlags(flags uint32) bool { func (s *VirtualSession) SetFlags(flags uint32) bool { for { - old := atomic.LoadUint32(&s.flags) + old := s.flags.Load() if old == flags { return false } - if atomic.CompareAndSwapUint32(&s.flags, old, flags) { + if s.flags.CompareAndSwap(old, flags) { return true } } } func (s *VirtualSession) Flags() uint32 { - return atomic.LoadUint32(&s.flags) + return s.flags.Load() } func (s *VirtualSession) Options() *AddSessionOptions {