Switch to atomic types from Go 1.19

This commit is contained in:
Joachim Bauch 2023-06-15 13:36:53 +02:00
parent 2c5ad32391
commit c134883138
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
17 changed files with 183 additions and 210 deletions

View file

@ -192,9 +192,9 @@ func TestCapabilities(t *testing.T) {
} }
func TestInvalidateCapabilities(t *testing.T) { func TestInvalidateCapabilities(t *testing.T) {
var called uint32 var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) { url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) {
atomic.AddUint32(&called, 1) called.Add(1)
}) })
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
@ -209,7 +209,7 @@ func TestInvalidateCapabilities(t *testing.T) {
t.Errorf("expected direct response") t.Errorf("expected direct response")
} }
if value := atomic.LoadUint32(&called); value != 1 { if value := called.Load(); value != 1 {
t.Errorf("expected called %d, got %d", 1, value) t.Errorf("expected called %d, got %d", 1, value)
} }
@ -224,7 +224,7 @@ func TestInvalidateCapabilities(t *testing.T) {
t.Errorf("expected direct response") t.Errorf("expected direct response")
} }
if value := atomic.LoadUint32(&called); value != 2 { if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value) t.Errorf("expected called %d, got %d", 2, value)
} }
@ -239,7 +239,7 @@ func TestInvalidateCapabilities(t *testing.T) {
t.Errorf("expected cached response") t.Errorf("expected cached response")
} }
if value := atomic.LoadUint32(&called); value != 2 { if value := called.Load(); value != 2 {
t.Errorf("expected called %d, got %d", 2, value) t.Errorf("expected called %d, got %d", 2, value)
} }
@ -258,7 +258,7 @@ func TestInvalidateCapabilities(t *testing.T) {
t.Errorf("expected direct response") t.Errorf("expected direct response")
} }
if value := atomic.LoadUint32(&called); value != 3 { if value := called.Load(); value != 3 {
t.Errorf("expected called %d, got %d", 3, value) t.Errorf("expected called %d, got %d", 3, value)
} }
} }

View file

@ -30,7 +30,6 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/mailru/easyjson" "github.com/mailru/easyjson"
@ -108,11 +107,11 @@ type Client struct {
addr string addr string
handler ClientHandler handler ClientHandler
agent string agent string
closed uint32 closed atomic.Int32
country *string country *string
logRTT bool logRTT bool
session unsafe.Pointer session atomic.Pointer[ClientSession]
mu sync.Mutex mu sync.Mutex
@ -150,7 +149,7 @@ func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler Cli
} }
func (c *Client) IsConnected() bool { func (c *Client) IsConnected() bool {
return atomic.LoadUint32(&c.closed) == 0 return c.closed.Load() == 0
} }
func (c *Client) IsAuthenticated() bool { func (c *Client) IsAuthenticated() bool {
@ -158,11 +157,11 @@ func (c *Client) IsAuthenticated() bool {
} }
func (c *Client) GetSession() *ClientSession { func (c *Client) GetSession() *ClientSession {
return (*ClientSession)(atomic.LoadPointer(&c.session)) return c.session.Load()
} }
func (c *Client) SetSession(session *ClientSession) { func (c *Client) SetSession(session *ClientSession) {
atomic.StorePointer(&c.session, unsafe.Pointer(session)) c.session.Store(session)
} }
func (c *Client) RemoteAddr() string { func (c *Client) RemoteAddr() string {
@ -188,7 +187,7 @@ func (c *Client) Country() string {
} }
func (c *Client) Close() { func (c *Client) Close() {
if atomic.LoadUint32(&c.closed) >= 2 { if c.closed.Load() >= 2 {
// Prevent reentrant call in case this was the second closing // Prevent reentrant call in case this was the second closing
// step. Would otherwise deadlock in the "Once.Do" call path // step. Would otherwise deadlock in the "Once.Do" call path
// through "Hub.processUnregister" (which calls "Close" again). // through "Hub.processUnregister" (which calls "Close" again).
@ -201,7 +200,7 @@ func (c *Client) Close() {
} }
func (c *Client) doClose() { func (c *Client) doClose() {
closed := atomic.AddUint32(&c.closed, 1) closed := c.closed.Add(1)
if closed == 1 { if closed == 1 {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -329,7 +328,7 @@ func (c *Client) ReadPump() {
} }
// Stop processing if the client was closed. // Stop processing if the client was closed.
if atomic.LoadUint32(&c.closed) != 0 { if !c.IsConnected() {
bufferPool.Put(decodeBuffer) bufferPool.Put(decodeBuffer)
break break
} }

View file

@ -81,8 +81,8 @@ const (
) )
type Stats struct { type Stats struct {
numRecvMessages uint64 numRecvMessages atomic.Uint64
numSentMessages uint64 numSentMessages atomic.Uint64
resetRecvMessages uint64 resetRecvMessages uint64
resetSentMessages uint64 resetSentMessages uint64
@ -90,8 +90,8 @@ type Stats struct {
} }
func (s *Stats) reset(start time.Time) { func (s *Stats) reset(start time.Time) {
s.resetRecvMessages = atomic.AddUint64(&s.numRecvMessages, 0) s.resetRecvMessages = s.numRecvMessages.Load()
s.resetSentMessages = atomic.AddUint64(&s.numSentMessages, 0) s.resetSentMessages = s.numSentMessages.Load()
s.start = start s.start = start
} }
@ -103,9 +103,9 @@ func (s *Stats) Log() {
return return
} }
totalSentMessages := atomic.AddUint64(&s.numSentMessages, 0) totalSentMessages := s.numSentMessages.Load()
sentMessages := totalSentMessages - s.resetSentMessages sentMessages := totalSentMessages - s.resetSentMessages
totalRecvMessages := atomic.AddUint64(&s.numRecvMessages, 0) totalRecvMessages := s.numRecvMessages.Load()
recvMessages := totalRecvMessages - s.resetRecvMessages recvMessages := totalRecvMessages - s.resetRecvMessages
log.Printf("Stats: sent=%d (%d/sec), recv=%d (%d/sec), delta=%d", log.Printf("Stats: sent=%d (%d/sec), recv=%d (%d/sec), delta=%d",
totalSentMessages, sentMessages/perSec, totalSentMessages, sentMessages/perSec,
@ -125,7 +125,7 @@ type SignalingClient struct {
conn *websocket.Conn conn *websocket.Conn
stats *Stats stats *Stats
closed uint32 closed atomic.Bool
stopChan chan struct{} stopChan chan struct{}
@ -164,7 +164,7 @@ func NewSignalingClient(cookie *securecookie.SecureCookie, url string, stats *St
} }
func (c *SignalingClient) Close() { func (c *SignalingClient) Close() {
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { if !c.closed.CompareAndSwap(false, true) {
return return
} }
@ -197,7 +197,7 @@ func (c *SignalingClient) Send(message *signaling.ClientMessage) {
} }
func (c *SignalingClient) processMessage(message *signaling.ServerMessage) { func (c *SignalingClient) processMessage(message *signaling.ServerMessage) {
atomic.AddUint64(&c.stats.numRecvMessages, 1) c.stats.numRecvMessages.Add(1)
switch message.Type { switch message.Type {
case "hello": case "hello":
c.processHelloMessage(message) c.processHelloMessage(message)
@ -334,7 +334,7 @@ func (c *SignalingClient) writeInternal(message *signaling.ClientMessage) bool {
} }
writer.Close() writer.Close()
atomic.AddUint64(&c.stats.numSentMessages, 1) c.stats.numSentMessages.Add(1)
return true return true
close: close:
@ -383,7 +383,7 @@ func (c *SignalingClient) SendMessages(clients []*SignalingClient) {
sessionIds[c] = c.PublicSessionId() sessionIds[c] = c.PublicSessionId()
} }
for atomic.LoadUint32(&c.closed) == 0 { for !c.closed.Load() {
now := time.Now() now := time.Now()
sender := c sender := c

View file

@ -31,7 +31,6 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
"github.com/pion/sdp/v3" "github.com/pion/sdp/v3"
) )
@ -50,9 +49,6 @@ var (
type ResponseHandlerFunc func(message *ClientMessage) bool type ResponseHandlerFunc func(message *ClientMessage) bool
type ClientSession struct { type ClientSession struct {
roomJoinTime int64
inCall uint32
hub *Hub hub *Hub
events AsyncEvents events AsyncEvents
privateId string privateId string
@ -64,6 +60,7 @@ type ClientSession struct {
userId string userId string
userData *json.RawMessage userData *json.RawMessage
inCall atomic.Uint32
supportsPermissions bool supportsPermissions bool
permissions map[Permission]bool permissions map[Permission]bool
@ -76,7 +73,8 @@ type ClientSession struct {
mu sync.Mutex mu sync.Mutex
client *Client client *Client
room unsafe.Pointer room atomic.Pointer[Room]
roomJoinTime atomic.Int64
roomSessionId string roomSessionId string
publisherWaiters ChannelWaiters publisherWaiters ChannelWaiters
@ -171,7 +169,7 @@ func (s *ClientSession) ClientType() string {
// GetInCall is only used for internal clients. // GetInCall is only used for internal clients.
func (s *ClientSession) GetInCall() int { func (s *ClientSession) GetInCall() int {
return int(atomic.LoadUint32(&s.inCall)) return int(s.inCall.Load())
} }
func (s *ClientSession) SetInCall(inCall int) bool { func (s *ClientSession) SetInCall(inCall int) bool {
@ -180,12 +178,12 @@ func (s *ClientSession) SetInCall(inCall int) bool {
} }
for { for {
old := atomic.LoadUint32(&s.inCall) old := s.inCall.Load()
if old == uint32(inCall) { if old == uint32(inCall) {
return false return false
} }
if atomic.CompareAndSwapUint32(&s.inCall, old, uint32(inCall)) { if s.inCall.CompareAndSwap(old, uint32(inCall)) {
return true return true
} }
} }
@ -340,11 +338,11 @@ func (s *ClientSession) IsExpired(now time.Time) bool {
} }
func (s *ClientSession) SetRoom(room *Room) { func (s *ClientSession) SetRoom(room *Room) {
atomic.StorePointer(&s.room, unsafe.Pointer(room)) s.room.Store(room)
if room != nil { if room != nil {
atomic.StoreInt64(&s.roomJoinTime, time.Now().UnixNano()) s.roomJoinTime.Store(time.Now().UnixNano())
} else { } else {
atomic.StoreInt64(&s.roomJoinTime, 0) s.roomJoinTime.Store(0)
} }
s.seenJoinedLock.Lock() s.seenJoinedLock.Lock()
@ -353,11 +351,11 @@ func (s *ClientSession) SetRoom(room *Room) {
} }
func (s *ClientSession) GetRoom() *Room { func (s *ClientSession) GetRoom() *Room {
return (*Room)(atomic.LoadPointer(&s.room)) return s.room.Load()
} }
func (s *ClientSession) getRoomJoinTime() time.Time { func (s *ClientSession) getRoomJoinTime() time.Time {
t := atomic.LoadInt64(&s.roomJoinTime) t := s.roomJoinTime.Load()
if t == 0 { if t == 0 {
return time.Time{} return time.Time{}
} }

View file

@ -26,7 +26,7 @@ import (
) )
type Closer struct { type Closer struct {
closed uint32 closed atomic.Bool
C chan struct{} C chan struct{}
} }
@ -37,11 +37,11 @@ func NewCloser() *Closer {
} }
func (c *Closer) IsClosed() bool { func (c *Closer) IsClosed() bool {
return atomic.LoadUint32(&c.closed) != 0 return c.closed.Load()
} }
func (c *Closer) Close() { func (c *Closer) Close() {
if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { if c.closed.CompareAndSwap(false, true) {
close(c.C) close(c.C)
} }
} }

View file

@ -51,7 +51,7 @@ const (
var ( var (
lookupGrpcIp = net.LookupIP // can be overwritten from tests lookupGrpcIp = net.LookupIP // can be overwritten from tests
customResolverPrefix uint64 customResolverPrefix atomic.Uint64
) )
func init() { func init() {
@ -75,12 +75,12 @@ func newGrpcClientImpl(conn grpc.ClientConnInterface) *grpcClientImpl {
} }
type GrpcClient struct { type GrpcClient struct {
isSelf uint32
ip net.IP ip net.IP
target string target string
conn *grpc.ClientConn conn *grpc.ClientConn
impl *grpcClientImpl impl *grpcClientImpl
isSelf atomic.Bool
} }
type customIpResolver struct { type customIpResolver struct {
@ -125,7 +125,7 @@ func NewGrpcClient(target string, ip net.IP, opts ...grpc.DialOption) (*GrpcClie
var conn *grpc.ClientConn var conn *grpc.ClientConn
var err error var err error
if ip != nil { if ip != nil {
prefix := atomic.AddUint64(&customResolverPrefix, 1) prefix := customResolverPrefix.Add(1)
addr := ip.String() addr := ip.String()
hostname := target hostname := target
if host, port, err := net.SplitHostPort(target); err == nil { if host, port, err := net.SplitHostPort(target); err == nil {
@ -168,15 +168,11 @@ func (c *GrpcClient) Close() error {
} }
func (c *GrpcClient) IsSelf() bool { func (c *GrpcClient) IsSelf() bool {
return atomic.LoadUint32(&c.isSelf) != 0 return c.isSelf.Load()
} }
func (c *GrpcClient) SetSelf(self bool) { func (c *GrpcClient) SetSelf(self bool) {
if self { c.isSelf.Store(self)
atomic.StoreUint32(&c.isSelf, 1)
} else {
atomic.StoreUint32(&c.isSelf, 0)
}
} }
func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) { func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) {

43
hub.go
View file

@ -112,9 +112,6 @@ func init() {
} }
type Hub struct { type Hub struct {
// 64-bit members that are accessed atomically must be 64-bit aligned.
sid uint64
events AsyncEvents events AsyncEvents
upgrader websocket.Upgrader upgrader websocket.Upgrader
cookie *securecookie.SecureCookie cookie *securecookie.SecureCookie
@ -123,8 +120,8 @@ type Hub struct {
welcome atomic.Value // *ServerMessage welcome atomic.Value // *ServerMessage
closer *Closer closer *Closer
readPumpActive uint32 readPumpActive atomic.Int32
writePumpActive uint32 writePumpActive atomic.Int32
roomUpdated chan *BackendServerRoomRequest roomUpdated chan *BackendServerRoomRequest
roomDeleted chan *BackendServerRoomRequest roomDeleted chan *BackendServerRoomRequest
@ -134,6 +131,7 @@ type Hub struct {
mu sync.RWMutex mu sync.RWMutex
ru sync.RWMutex ru sync.RWMutex
sid atomic.Uint64
clients map[uint64]*Client clients map[uint64]*Client
sessions map[uint64]Session sessions map[uint64]Session
rooms map[string]*Room rooms map[string]*Room
@ -160,7 +158,7 @@ type Hub struct {
geoip *GeoLookup geoip *GeoLookup
geoipOverrides map[*net.IPNet]string geoipOverrides map[*net.IPNet]string
geoipUpdating int32 geoipUpdating atomic.Bool
rpcServer *GrpcServer rpcServer *GrpcServer
rpcClients *GrpcClients rpcClients *GrpcClients
@ -414,12 +412,12 @@ func (h *Hub) updateGeoDatabase() {
return return
} }
if !atomic.CompareAndSwapInt32(&h.geoipUpdating, 0, 1) { if !h.geoipUpdating.CompareAndSwap(false, true) {
// Already updating // Already updating
return return
} }
defer atomic.CompareAndSwapInt32(&h.geoipUpdating, 1, 0) defer h.geoipUpdating.Store(false)
delay := time.Second delay := time.Second
for !h.closer.IsClosed() { for !h.closer.IsClosed() {
err := h.geoip.Update() err := h.geoip.Update()
@ -699,9 +697,9 @@ func (h *Hub) sendWelcome(client *Client) {
} }
func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData { func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData {
sid := atomic.AddUint64(&h.sid, 1) sid := h.sid.Add(1)
for sid == 0 { for sid == 0 {
sid = atomic.AddUint64(&h.sid, 1) sid = h.sid.Add(1)
} }
sessionIdData := &SessionIdData{ sessionIdData := &SessionIdData{
Sid: sid, Sid: sid,
@ -725,10 +723,6 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
return return
} }
sid := atomic.AddUint64(&h.sid, 1)
for sid == 0 {
sid = atomic.AddUint64(&h.sid, 1)
}
sessionIdData := h.newSessionIdData(backend) sessionIdData := h.newSessionIdData(backend)
privateSessionId, err := h.encodeSessionId(sessionIdData, privateSessionName) privateSessionId, err := h.encodeSessionId(sessionIdData, privateSessionName)
if err != nil { if err != nil {
@ -764,7 +758,8 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
} }
if limit := uint32(backend.Limit()); limit > 0 && h.rpcClients != nil { if limit := uint32(backend.Limit()); limit > 0 && h.rpcClients != nil {
totalCount := uint32(backend.Len()) var totalCount atomic.Uint32
totalCount.Add(uint32(backend.Len()))
var wg sync.WaitGroup var wg sync.WaitGroup
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
@ -781,12 +776,12 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
if count > 0 { if count > 0 {
log.Printf("%d sessions connected for %s on %s", count, backend.Url(), c.Target()) log.Printf("%d sessions connected for %s on %s", count, backend.Url(), c.Target())
atomic.AddUint32(&totalCount, count) totalCount.Add(count)
} }
}(client) }(client)
} }
wg.Wait() wg.Wait()
if totalCount > limit { if totalCount.Load() > limit {
backend.RemoveSession(session) backend.RemoveSession(session)
log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), SessionLimitExceeded) log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), SessionLimitExceeded)
session.Close() session.Close()
@ -2054,7 +2049,7 @@ func (h *Hub) isInSameCallRemote(ctx context.Context, senderSession *ClientSessi
return false return false
} }
var result int32 var result atomic.Bool
var wg sync.WaitGroup var wg sync.WaitGroup
rpcCtx, cancel := context.WithCancel(ctx) rpcCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@ -2074,12 +2069,12 @@ func (h *Hub) isInSameCallRemote(ctx context.Context, senderSession *ClientSessi
} }
cancel() cancel()
atomic.StoreInt32(&result, 1) result.Store(true)
}(client) }(client)
} }
wg.Wait() wg.Wait()
return atomic.LoadInt32(&result) != 0 return result.Load()
} }
func (h *Hub) isInSameCall(ctx context.Context, senderSession *ClientSession, recipientSessionId string) bool { func (h *Hub) isInSameCall(ctx context.Context, senderSession *ClientSession, recipientSessionId string) bool {
@ -2364,13 +2359,13 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
h.processNewClient(client) h.processNewClient(client)
go func(h *Hub) { go func(h *Hub) {
atomic.AddUint32(&h.writePumpActive, 1) h.writePumpActive.Add(1)
defer atomic.AddUint32(&h.writePumpActive, ^uint32(0)) defer h.writePumpActive.Add(-1)
client.WritePump() client.WritePump()
}(h) }(h)
go func(h *Hub) { go func(h *Hub) {
atomic.AddUint32(&h.readPumpActive, 1) h.readPumpActive.Add(1)
defer atomic.AddUint32(&h.readPumpActive, ^uint32(0)) defer h.readPumpActive.Add(-1)
client.ReadPump() client.ReadPump()
}(h) }(h)
} }

View file

@ -42,7 +42,6 @@ import (
"reflect" "reflect"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -279,8 +278,8 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
h.ru.Lock() h.ru.Lock()
rooms := len(h.rooms) rooms := len(h.rooms)
h.ru.Unlock() h.ru.Unlock()
readActive := atomic.LoadUint32(&h.readPumpActive) readActive := h.readPumpActive.Load()
writeActive := atomic.LoadUint32(&h.writePumpActive) writeActive := h.writePumpActive.Load()
if clients == 0 && rooms == 0 && sessions == 0 && readActive == 0 && writeActive == 0 { if clients == 0 && rooms == 0 && sessions == 0 && readActive == 0 && writeActive == 0 {
break break
} }
@ -1631,7 +1630,7 @@ func TestClientHelloResumeOtherHub(t *testing.T) {
} }
// Simulate a restart of the hub. // Simulate a restart of the hub.
atomic.StoreUint64(&hub.sid, 0) hub.sid.Store(0)
sessions := make([]Session, 0) sessions := make([]Session, 0)
hub.mu.Lock() hub.mu.Lock()
for _, session := range hub.sessions { for _, session := range hub.sessions {

View file

@ -221,8 +221,6 @@ func (l *dummyGatewayListener) ConnectionInterrupted() {
// Gateway represents a connection to an instance of the Janus Gateway. // Gateway represents a connection to an instance of the Janus Gateway.
type JanusGateway struct { type JanusGateway struct {
nextTransaction uint64
listener GatewayListener listener GatewayListener
// Sessions is a map of the currently active sessions to the gateway. // Sessions is a map of the currently active sessions to the gateway.
@ -232,8 +230,9 @@ type JanusGateway struct {
// and Gateway.Unlock() methods provided by the embedded sync.Mutex. // and Gateway.Unlock() methods provided by the embedded sync.Mutex.
sync.Mutex sync.Mutex
conn *websocket.Conn conn *websocket.Conn
transactions map[uint64]*transaction nextTransaction atomic.Uint64
transactions map[uint64]*transaction
closer *Closer closer *Closer
@ -328,7 +327,7 @@ func (gateway *JanusGateway) removeTransaction(id uint64) {
} }
func (gateway *JanusGateway) send(msg map[string]interface{}, t *transaction) (uint64, error) { func (gateway *JanusGateway) send(msg map[string]interface{}, t *transaction) (uint64, error) {
id := atomic.AddUint64(&gateway.nextTransaction, 1) id := gateway.nextTransaction.Add(1)
msg["transaction"] = strconv.FormatUint(id, 10) msg["transaction"] = strconv.FormatUint(id, 10)
data, err := json.Marshal(msg) data, err := json.Marshal(msg)
if err != nil { if err != nil {

View file

@ -132,9 +132,6 @@ type clientInterface interface {
} }
type mcuJanus struct { type mcuJanus struct {
// 64-bit members that are accessed atomically must be 64-bit aligned.
clientId uint64
url string url string
mu sync.Mutex mu sync.Mutex
@ -150,6 +147,7 @@ type mcuJanus struct {
muClients sync.Mutex muClients sync.Mutex
clients map[clientInterface]bool clients map[clientInterface]bool
clientId atomic.Uint64
publishers map[string]*mcuJanusPublisher publishers map[string]*mcuJanusPublisher
publisherCreated Notifier publisherCreated Notifier
@ -799,7 +797,7 @@ func (m *mcuJanus) NewPublisher(ctx context.Context, listener McuListener, id st
mcu: m, mcu: m,
listener: listener, listener: listener,
id: atomic.AddUint64(&m.clientId, 1), id: m.clientId.Add(1),
session: session, session: session,
roomId: roomId, roomId: roomId,
sid: sid, sid: sid,
@ -1040,7 +1038,7 @@ func (m *mcuJanus) NewSubscriber(ctx context.Context, listener McuListener, publ
mcu: m, mcu: m,
listener: listener, listener: listener,
id: atomic.AddUint64(&m.clientId, 1), id: m.clientId.Add(1),
roomId: pub.roomId, roomId: pub.roomId,
sid: strconv.FormatUint(handle.Id, 10), sid: strconv.FormatUint(handle.Id, 10),
streamType: streamType, streamType: streamType,

View file

@ -294,31 +294,29 @@ func (s *mcuProxySubscriber) ProcessEvent(msg *EventProxyServerMessage) {
} }
type mcuProxyConnection struct { type mcuProxyConnection struct {
// 64-bit members that are accessed atomically must be 64-bit aligned.
reconnectInterval int64
msgId int64
load int64
proxy *mcuProxy proxy *mcuProxy
rawUrl string rawUrl string
url *url.URL url *url.URL
ip net.IP ip net.IP
load atomic.Int64
mu sync.Mutex mu sync.Mutex
closer *Closer closer *Closer
closedDone *Closer closedDone *Closer
closed uint32 closed atomic.Bool
conn *websocket.Conn conn *websocket.Conn
connectedSince time.Time connectedSince time.Time
reconnectTimer *time.Timer reconnectTimer *time.Timer
shutdownScheduled uint32 reconnectInterval atomic.Int64
closeScheduled uint32 shutdownScheduled atomic.Bool
trackClose uint32 closeScheduled atomic.Bool
temporary uint32 trackClose atomic.Bool
temporary atomic.Bool
connectedNotifier SingleNotifier connectedNotifier SingleNotifier
msgId atomic.Int64
helloMsgId string helloMsgId string
sessionId string sessionId string
country atomic.Value country atomic.Value
@ -340,19 +338,19 @@ func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP) (*mcuProx
} }
conn := &mcuProxyConnection{ conn := &mcuProxyConnection{
proxy: proxy, proxy: proxy,
rawUrl: baseUrl, rawUrl: baseUrl,
url: parsed, url: parsed,
ip: ip, ip: ip,
closer: NewCloser(), closer: NewCloser(),
closedDone: NewCloser(), closedDone: NewCloser(),
reconnectInterval: int64(initialReconnectInterval), callbacks: make(map[string]func(*ProxyServerMessage)),
load: loadNotConnected, publishers: make(map[string]*mcuProxyPublisher),
callbacks: make(map[string]func(*ProxyServerMessage)), publisherIds: make(map[string]string),
publishers: make(map[string]*mcuProxyPublisher), subscribers: make(map[string]*mcuProxySubscriber),
publisherIds: make(map[string]string),
subscribers: make(map[string]*mcuProxySubscriber),
} }
conn.reconnectInterval.Store(int64(initialReconnectInterval))
conn.load.Store(loadNotConnected)
conn.country.Store("") conn.country.Store("")
return conn, nil return conn, nil
} }
@ -405,7 +403,7 @@ func (c *mcuProxyConnection) GetStats() *mcuProxyConnectionStats {
} }
func (c *mcuProxyConnection) Load() int64 { func (c *mcuProxyConnection) Load() int64 {
return atomic.LoadInt64(&c.load) return c.load.Load()
} }
func (c *mcuProxyConnection) Country() string { func (c *mcuProxyConnection) Country() string {
@ -413,31 +411,31 @@ func (c *mcuProxyConnection) Country() string {
} }
func (c *mcuProxyConnection) IsTemporary() bool { func (c *mcuProxyConnection) IsTemporary() bool {
return atomic.LoadUint32(&c.temporary) != 0 return c.temporary.Load()
} }
func (c *mcuProxyConnection) setTemporary() { func (c *mcuProxyConnection) setTemporary() {
atomic.StoreUint32(&c.temporary, 1) c.temporary.Store(true)
} }
func (c *mcuProxyConnection) clearTemporary() { func (c *mcuProxyConnection) clearTemporary() {
atomic.StoreUint32(&c.temporary, 0) c.temporary.Store(false)
} }
func (c *mcuProxyConnection) IsShutdownScheduled() bool { func (c *mcuProxyConnection) IsShutdownScheduled() bool {
return atomic.LoadUint32(&c.shutdownScheduled) != 0 || atomic.LoadUint32(&c.closeScheduled) != 0 return c.shutdownScheduled.Load() || c.closeScheduled.Load()
} }
func (c *mcuProxyConnection) readPump() { func (c *mcuProxyConnection) readPump() {
defer func() { defer func() {
if atomic.LoadUint32(&c.closed) == 0 { if !c.closed.Load() {
c.scheduleReconnect() c.scheduleReconnect()
} else { } else {
c.closedDone.Close() c.closedDone.Close()
} }
}() }()
defer c.close() defer c.close()
defer atomic.StoreInt64(&c.load, loadNotConnected) defer c.load.Store(loadNotConnected)
c.mu.Lock() c.mu.Lock()
conn := c.conn conn := c.conn
@ -539,7 +537,7 @@ func (c *mcuProxyConnection) sendClose() error {
} }
func (c *mcuProxyConnection) stop(ctx context.Context) { func (c *mcuProxyConnection) stop(ctx context.Context) {
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { if !c.closed.CompareAndSwap(false, true) {
return return
} }
@ -571,18 +569,18 @@ func (c *mcuProxyConnection) close() {
if c.conn != nil { if c.conn != nil {
c.conn.Close() c.conn.Close()
c.conn = nil c.conn = nil
if atomic.CompareAndSwapUint32(&c.trackClose, 1, 0) { if c.trackClose.CompareAndSwap(true, false) {
statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Dec() statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Dec()
} }
} }
} }
func (c *mcuProxyConnection) stopCloseIfEmpty() { func (c *mcuProxyConnection) stopCloseIfEmpty() {
atomic.StoreUint32(&c.closeScheduled, 0) c.closeScheduled.Store(false)
} }
func (c *mcuProxyConnection) closeIfEmpty() bool { func (c *mcuProxyConnection) closeIfEmpty() bool {
atomic.StoreUint32(&c.closeScheduled, 1) c.closeScheduled.Store(true)
var total int64 var total int64
c.publishersLock.RLock() c.publishersLock.RLock()
@ -620,14 +618,14 @@ func (c *mcuProxyConnection) scheduleReconnect() {
return return
} }
interval := atomic.LoadInt64(&c.reconnectInterval) interval := c.reconnectInterval.Load()
c.reconnectTimer.Reset(time.Duration(interval)) c.reconnectTimer.Reset(time.Duration(interval))
interval = interval * 2 interval = interval * 2
if interval > int64(maxReconnectInterval) { if interval > int64(maxReconnectInterval) {
interval = int64(maxReconnectInterval) interval = int64(maxReconnectInterval)
} }
atomic.StoreInt64(&c.reconnectInterval, interval) c.reconnectInterval.Store(interval)
} }
func (c *mcuProxyConnection) reconnect() { func (c *mcuProxyConnection) reconnect() {
@ -673,15 +671,15 @@ func (c *mcuProxyConnection) reconnect() {
} }
log.Printf("Connected to %s", c) log.Printf("Connected to %s", c)
atomic.StoreUint32(&c.closed, 0) c.closed.Store(false)
c.mu.Lock() c.mu.Lock()
c.connectedSince = time.Now() c.connectedSince = time.Now()
c.conn = conn c.conn = conn
c.mu.Unlock() c.mu.Unlock()
atomic.StoreInt64(&c.reconnectInterval, int64(initialReconnectInterval)) c.reconnectInterval.Store(int64(initialReconnectInterval))
atomic.StoreUint32(&c.shutdownScheduled, 0) c.shutdownScheduled.Store(false)
if err := c.sendHello(); err != nil { if err := c.sendHello(); err != nil {
log.Printf("Could not send hello request to %s: %s", c, err) log.Printf("Could not send hello request to %s: %s", c, err)
c.scheduleReconnect() c.scheduleReconnect()
@ -723,7 +721,7 @@ func (c *mcuProxyConnection) removePublisher(publisher *mcuProxyPublisher) {
} }
delete(c.publisherIds, publisher.id+"|"+publisher.StreamType()) delete(c.publisherIds, publisher.id+"|"+publisher.StreamType())
if len(c.publishers) == 0 && (atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary()) { if len(c.publishers) == 0 && (c.closeScheduled.Load() || c.IsTemporary()) {
go c.closeIfEmpty() go c.closeIfEmpty()
} }
} }
@ -740,7 +738,7 @@ func (c *mcuProxyConnection) clearPublishers() {
c.publishers = make(map[string]*mcuProxyPublisher) c.publishers = make(map[string]*mcuProxyPublisher)
c.publisherIds = make(map[string]string) c.publisherIds = make(map[string]string)
if atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary() { if c.closeScheduled.Load() || c.IsTemporary() {
go c.closeIfEmpty() go c.closeIfEmpty()
} }
} }
@ -754,7 +752,7 @@ func (c *mcuProxyConnection) removeSubscriber(subscriber *mcuProxySubscriber) {
statsSubscribersCurrent.WithLabelValues(subscriber.StreamType()).Dec() statsSubscribersCurrent.WithLabelValues(subscriber.StreamType()).Dec()
} }
if len(c.subscribers) == 0 && (atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary()) { if len(c.subscribers) == 0 && (c.closeScheduled.Load() || c.IsTemporary()) {
go c.closeIfEmpty() go c.closeIfEmpty()
} }
} }
@ -770,7 +768,7 @@ func (c *mcuProxyConnection) clearSubscribers() {
}(c.subscribers) }(c.subscribers)
c.subscribers = make(map[string]*mcuProxySubscriber) c.subscribers = make(map[string]*mcuProxySubscriber)
if atomic.LoadUint32(&c.closeScheduled) != 0 || c.IsTemporary() { if c.closeScheduled.Load() || c.IsTemporary() {
go c.closeIfEmpty() go c.closeIfEmpty()
} }
} }
@ -831,7 +829,7 @@ func (c *mcuProxyConnection) processMessage(msg *ProxyServerMessage) {
} else { } else {
log.Printf("Received session %s from %s", c.sessionId, c) log.Printf("Received session %s from %s", c.sessionId, c)
} }
if atomic.CompareAndSwapUint32(&c.trackClose, 0, 1) { if c.trackClose.CompareAndSwap(false, true) {
statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Inc() statsConnectedProxyBackendsCurrent.WithLabelValues(c.Country()).Inc()
} }
@ -902,12 +900,12 @@ func (c *mcuProxyConnection) processEvent(msg *ProxyServerMessage) {
if proxyDebugMessages { if proxyDebugMessages {
log.Printf("Load of %s now at %d", c, event.Load) log.Printf("Load of %s now at %d", c, event.Load)
} }
atomic.StoreInt64(&c.load, event.Load) c.load.Store(event.Load)
statsProxyBackendLoadCurrent.WithLabelValues(c.url.String()).Set(float64(event.Load)) statsProxyBackendLoadCurrent.WithLabelValues(c.url.String()).Set(float64(event.Load))
return return
case "shutdown-scheduled": case "shutdown-scheduled":
log.Printf("Proxy %s is scheduled to shutdown", c) log.Printf("Proxy %s is scheduled to shutdown", c)
atomic.StoreUint32(&c.shutdownScheduled, 1) c.shutdownScheduled.Store(true)
return return
} }
@ -945,7 +943,7 @@ func (c *mcuProxyConnection) processBye(msg *ProxyServerMessage) {
} }
func (c *mcuProxyConnection) sendHello() error { func (c *mcuProxyConnection) sendHello() error {
c.helloMsgId = strconv.FormatInt(atomic.AddInt64(&c.msgId, 1), 10) c.helloMsgId = strconv.FormatInt(c.msgId.Add(1), 10)
msg := &ProxyClientMessage{ msg := &ProxyClientMessage{
Id: c.helloMsgId, Id: c.helloMsgId,
Type: "hello", Type: "hello",
@ -992,7 +990,7 @@ func (c *mcuProxyConnection) sendMessageLocked(msg *ProxyClientMessage) error {
} }
func (c *mcuProxyConnection) performAsyncRequest(ctx context.Context, msg *ProxyClientMessage, callback func(err error, response *ProxyServerMessage)) { func (c *mcuProxyConnection) performAsyncRequest(ctx context.Context, msg *ProxyClientMessage, callback func(err error, response *ProxyServerMessage)) {
msgId := strconv.FormatInt(atomic.AddInt64(&c.msgId, 1), 10) msgId := strconv.FormatInt(c.msgId.Add(1), 10)
msg.Id = msgId msg.Id = msgId
c.mu.Lock() c.mu.Lock()
@ -1094,10 +1092,6 @@ func (c *mcuProxyConnection) newSubscriber(ctx context.Context, listener McuList
} }
type mcuProxy struct { type mcuProxy struct {
// 64-bit members that are accessed atomically must be 64-bit aligned.
connRequests int64
nextSort int64
urlType string urlType string
tokenId string tokenId string
tokenKey *rsa.PrivateKey tokenKey *rsa.PrivateKey
@ -1113,6 +1107,8 @@ type mcuProxy struct {
connectionsMap map[string][]*mcuProxyConnection connectionsMap map[string][]*mcuProxyConnection
connectionsMu sync.RWMutex connectionsMu sync.RWMutex
proxyTimeout time.Duration proxyTimeout time.Duration
connRequests atomic.Int64
nextSort atomic.Int64
dnsDiscovery bool dnsDiscovery bool
stopping chan struct{} stopping chan struct{}
@ -1510,7 +1506,7 @@ func (m *mcuProxy) configureStatic(config *goconf.ConfigFile, fromReload bool) e
} }
if changed { if changed {
atomic.StoreInt64(&m.nextSort, 0) m.nextSort.Store(0)
} }
} else { } else {
for u, conns := range created { for u, conns := range created {
@ -1644,7 +1640,7 @@ func (m *mcuProxy) EtcdKeyUpdated(client *EtcdClient, key string, data []byte) {
m.urlToKey[info.Address] = key m.urlToKey[info.Address] = key
m.connections = append(m.connections, conn) m.connections = append(m.connections, conn)
m.connectionsMap[info.Address] = []*mcuProxyConnection{conn} m.connectionsMap[info.Address] = []*mcuProxyConnection{conn}
atomic.StoreInt64(&m.nextSort, 0) m.nextSort.Store(0)
} }
} }
@ -1696,7 +1692,7 @@ func (m *mcuProxy) removeConnection(c *mcuProxyConnection) {
m.connectionsMap[c.rawUrl] = conns m.connectionsMap[c.rawUrl] = conns
} }
atomic.StoreInt64(&m.nextSort, 0) m.nextSort.Store(0)
} }
} }
@ -1829,8 +1825,8 @@ func (m *mcuProxy) getSortedConnections(initiator McuInitiator) []*mcuProxyConne
// Connections are re-sorted every <connectionSortRequests> requests or // Connections are re-sorted every <connectionSortRequests> requests or
// every <connectionSortInterval>. // every <connectionSortInterval>.
now := time.Now().UnixNano() now := time.Now().UnixNano()
if atomic.AddInt64(&m.connRequests, 1)%connectionSortRequests == 0 || atomic.LoadInt64(&m.nextSort) <= now { if m.connRequests.Add(1)%connectionSortRequests == 0 || m.nextSort.Load() <= now {
atomic.StoreInt64(&m.nextSort, now+int64(connectionSortInterval)) m.nextSort.Store(now + int64(connectionSortInterval))
sorted := make(mcuProxyConnectionsList, len(connections)) sorted := make(mcuProxyConnectionsList, len(connections))
copy(sorted, connections) copy(sorted, connections)

View file

@ -139,7 +139,7 @@ func (m *TestMCU) NewSubscriber(ctx context.Context, listener McuListener, publi
} }
type TestMCUClient struct { type TestMCUClient struct {
closed int32 closed atomic.Bool
id string id string
sid string sid string
@ -159,13 +159,13 @@ func (c *TestMCUClient) StreamType() string {
} }
func (c *TestMCUClient) Close(ctx context.Context) { func (c *TestMCUClient) Close(ctx context.Context) {
if atomic.CompareAndSwapInt32(&c.closed, 0, 1) { if c.closed.CompareAndSwap(false, true) {
log.Printf("Close MCU client %s", c.id) log.Printf("Close MCU client %s", c.id)
} }
} }
func (c *TestMCUClient) isClosed() bool { func (c *TestMCUClient) isClosed() bool {
return atomic.LoadInt32(&c.closed) != 0 return c.closed.Load()
} }
type TestMCUPublisher struct { type TestMCUPublisher struct {

View file

@ -63,7 +63,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
} }
ch := make(chan struct{}) ch := make(chan struct{})
received := int32(0) var received atomic.Int32
max := int32(20) max := int32(20)
ready := make(chan struct{}) ready := make(chan struct{})
quit := make(chan struct{}) quit := make(chan struct{})
@ -73,7 +73,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
for { for {
select { select {
case <-dest: case <-dest:
total := atomic.AddInt32(&received, 1) total := received.Add(1)
if total == max { if total == max {
err := sub.Unsubscribe() err := sub.Unsubscribe()
if err != nil { if err != nil {
@ -98,8 +98,7 @@ func testNatsClient_Subscribe(t *testing.T, client NatsClient) {
} }
<-ch <-ch
r := atomic.LoadInt32(&received) if r := received.Load(); r != max {
if r != max {
t.Fatalf("Received wrong # of messages: %d vs %d", r, max) t.Fatalf("Received wrong # of messages: %d vs %d", r, max)
} }
} }

View file

@ -24,7 +24,6 @@ package main
import ( import (
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
signaling "github.com/strukturag/nextcloud-spreed-signaling" signaling "github.com/strukturag/nextcloud-spreed-signaling"
@ -35,7 +34,7 @@ type ProxyClient struct {
proxy *ProxyServer proxy *ProxyServer
session unsafe.Pointer session atomic.Pointer[ProxySession]
} }
func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*ProxyClient, error) { func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*ProxyClient, error) {
@ -47,11 +46,11 @@ func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*Pro
} }
func (c *ProxyClient) GetSession() *ProxySession { func (c *ProxyClient) GetSession() *ProxySession {
return (*ProxySession)(atomic.LoadPointer(&c.session)) return c.session.Load()
} }
func (c *ProxyClient) SetSession(session *ProxySession) { func (c *ProxyClient) SetSession(session *ProxySession) {
atomic.StorePointer(&c.session, unsafe.Pointer(session)) c.session.Store(session)
} }
func (c *ProxyClient) OnClosed(client *signaling.Client) { func (c *ProxyClient) OnClosed(client *signaling.Client) {

View file

@ -82,25 +82,23 @@ var (
) )
type ProxyServer struct { type ProxyServer struct {
// 64-bit members that are accessed atomically must be 64-bit aligned.
load int64
version string version string
country string country string
url string url string
mcu signaling.Mcu mcu signaling.Mcu
stopped uint32 stopped atomic.Bool
load atomic.Int64
shutdownChannel chan struct{} shutdownChannel chan struct{}
shutdownScheduled uint32 shutdownScheduled atomic.Bool
upgrader websocket.Upgrader upgrader websocket.Upgrader
tokens ProxyTokens tokens ProxyTokens
statsAllowedIps *signaling.AllowedIps statsAllowedIps *signaling.AllowedIps
sid uint64 sid atomic.Uint64
cookie *securecookie.SecureCookie cookie *securecookie.SecureCookie
sessions map[uint64]*ProxySession sessions map[uint64]*ProxySession
sessionsLock sync.RWMutex sessionsLock sync.RWMutex
@ -279,12 +277,12 @@ loop:
for { for {
select { select {
case <-updateLoadTicker.C: case <-updateLoadTicker.C:
if atomic.LoadUint32(&s.stopped) != 0 { if s.stopped.Load() {
break loop break loop
} }
s.updateLoad() s.updateLoad()
case <-expireSessionsTicker.C: case <-expireSessionsTicker.C:
if atomic.LoadUint32(&s.stopped) != 0 { if s.stopped.Load() {
break loop break loop
} }
s.expireSessions() s.expireSessions()
@ -296,12 +294,12 @@ func (s *ProxyServer) updateLoad() {
// TODO: Take maximum bandwidth of clients into account when calculating // TODO: Take maximum bandwidth of clients into account when calculating
// load (screensharing requires more than regular audio/video). // load (screensharing requires more than regular audio/video).
load := s.GetClientCount() load := s.GetClientCount()
if load == atomic.LoadInt64(&s.load) { if load == s.load.Load() {
return return
} }
atomic.StoreInt64(&s.load, load) s.load.Store(load)
if atomic.LoadUint32(&s.shutdownScheduled) != 0 { if s.shutdownScheduled.Load() {
// Server is scheduled to shutdown, no need to update clients with current load. // Server is scheduled to shutdown, no need to update clients with current load.
return return
} }
@ -349,7 +347,7 @@ func (s *ProxyServer) expireSessions() {
} }
func (s *ProxyServer) Stop() { func (s *ProxyServer) Stop() {
if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) { if !s.stopped.CompareAndSwap(false, true) {
return return
} }
@ -364,7 +362,7 @@ func (s *ProxyServer) ShutdownChannel() <-chan struct{} {
} }
func (s *ProxyServer) ScheduleShutdown() { func (s *ProxyServer) ScheduleShutdown() {
if !atomic.CompareAndSwapUint32(&s.shutdownScheduled, 0, 1) { if !s.shutdownScheduled.CompareAndSwap(false, true) {
return return
} }
@ -449,7 +447,7 @@ func (s *ProxyServer) onMcuConnected() {
} }
func (s *ProxyServer) onMcuDisconnected() { func (s *ProxyServer) onMcuDisconnected() {
if atomic.LoadUint32(&s.stopped) != 0 { if s.stopped.Load() {
// Shutting down, no need to notify. // Shutting down, no need to notify.
return return
} }
@ -473,7 +471,7 @@ func (s *ProxyServer) sendCurrentLoad(session *ProxySession) {
Type: "event", Type: "event",
Event: &signaling.EventProxyServerMessage{ Event: &signaling.EventProxyServerMessage{
Type: "update-load", Type: "update-load",
Load: atomic.LoadInt64(&s.load), Load: s.load.Load(),
}, },
} }
session.sendMessage(msg) session.sendMessage(msg)
@ -535,7 +533,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) {
log.Printf("Resumed session %s", session.PublicId()) log.Printf("Resumed session %s", session.PublicId())
session.MarkUsed() session.MarkUsed()
if atomic.LoadUint32(&s.shutdownScheduled) != 0 { if s.shutdownScheduled.Load() {
s.sendShutdownScheduled(session) s.sendShutdownScheduled(session)
} else { } else {
s.sendCurrentLoad(session) s.sendCurrentLoad(session)
@ -576,7 +574,7 @@ func (s *ProxyServer) processMessage(client *ProxyClient, data []byte) {
}, },
} }
client.SendMessage(response) client.SendMessage(response)
if atomic.LoadUint32(&s.shutdownScheduled) != 0 { if s.shutdownScheduled.Load() {
s.sendShutdownScheduled(session) s.sendShutdownScheduled(session)
} else { } else {
s.sendCurrentLoad(session) s.sendCurrentLoad(session)
@ -610,7 +608,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s
switch cmd.Type { switch cmd.Type {
case "create-publisher": case "create-publisher":
if atomic.LoadUint32(&s.shutdownScheduled) != 0 { if s.shutdownScheduled.Load() {
session.sendMessage(message.NewErrorServerMessage(ShutdownScheduled)) session.sendMessage(message.NewErrorServerMessage(ShutdownScheduled))
return return
} }
@ -873,9 +871,9 @@ func (s *ProxyServer) NewSession(hello *signaling.HelloProxyClientMessage) (*Pro
return nil, TokenExpired return nil, TokenExpired
} }
sid := atomic.AddUint64(&s.sid, 1) sid := s.sid.Add(1)
for sid == 0 { for sid == 0 {
sid = atomic.AddUint64(&s.sid, 1) sid = s.sid.Add(1)
} }
sessionIdData := &signaling.SessionIdData{ sessionIdData := &signaling.SessionIdData{
@ -954,7 +952,7 @@ func (s *ProxyServer) DeleteClient(id string, client signaling.McuClient) bool {
delete(s.clients, id) delete(s.clients, id)
delete(s.clientIds, client.Id()) delete(s.clientIds, client.Id())
if len(s.clients) == 0 && atomic.LoadUint32(&s.shutdownScheduled) != 0 { if len(s.clients) == 0 && s.shutdownScheduled.Load() {
go close(s.shutdownChannel) go close(s.shutdownChannel)
} }
return true return true
@ -981,7 +979,7 @@ func (s *ProxyServer) GetClientId(client signaling.McuClient) string {
func (s *ProxyServer) getStats() map[string]interface{} { func (s *ProxyServer) getStats() map[string]interface{} {
result := map[string]interface{}{ result := map[string]interface{}{
"sessions": s.GetSessionsCount(), "sessions": s.GetSessionsCount(),
"load": atomic.LoadInt64(&s.load), "load": s.load.Load(),
"mcu": s.mcu.GetStats(), "mcu": s.mcu.GetStats(),
} }
return result return result

View file

@ -37,12 +37,10 @@ const (
) )
type ProxySession struct { type ProxySession struct {
// 64-bit members that are accessed atomically must be 64-bit aligned. proxy *ProxyServer
lastUsed int64 id string
sid uint64
proxy *ProxyServer lastUsed atomic.Int64
id string
sid uint64
clientLock sync.Mutex clientLock sync.Mutex
client *ProxyClient client *ProxyClient
@ -58,11 +56,10 @@ type ProxySession struct {
} }
func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession { func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession {
return &ProxySession{ result := &ProxySession{
proxy: proxy, proxy: proxy,
id: id, id: id,
sid: sid, sid: sid,
lastUsed: time.Now().UnixNano(),
publishers: make(map[string]signaling.McuPublisher), publishers: make(map[string]signaling.McuPublisher),
publisherIds: make(map[signaling.McuPublisher]string), publisherIds: make(map[signaling.McuPublisher]string),
@ -70,6 +67,8 @@ func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession {
subscribers: make(map[string]signaling.McuSubscriber), subscribers: make(map[string]signaling.McuSubscriber),
subscriberIds: make(map[signaling.McuSubscriber]string), subscriberIds: make(map[signaling.McuSubscriber]string),
} }
result.MarkUsed()
return result
} }
func (s *ProxySession) PublicId() string { func (s *ProxySession) PublicId() string {
@ -81,7 +80,7 @@ func (s *ProxySession) Sid() uint64 {
} }
func (s *ProxySession) LastUsed() time.Time { func (s *ProxySession) LastUsed() time.Time {
lastUsed := atomic.LoadInt64(&s.lastUsed) lastUsed := s.lastUsed.Load()
return time.Unix(0, lastUsed) return time.Unix(0, lastUsed)
} }
@ -92,7 +91,7 @@ func (s *ProxySession) IsExpired() bool {
func (s *ProxySession) MarkUsed() { func (s *ProxySession) MarkUsed() {
now := time.Now() now := time.Now()
atomic.StoreInt64(&s.lastUsed, now.UnixNano()) s.lastUsed.Store(now.UnixNano())
} }
func (s *ProxySession) Close() { func (s *ProxySession) Close() {

View file

@ -28,7 +28,6 @@ import (
"net/url" "net/url"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
) )
const ( const (
@ -38,19 +37,18 @@ const (
) )
type VirtualSession struct { type VirtualSession struct {
inCall uint32
hub *Hub hub *Hub
session *ClientSession session *ClientSession
privateId string privateId string
publicId string publicId string
data *SessionIdData data *SessionIdData
room unsafe.Pointer room atomic.Pointer[Room]
sessionId string sessionId string
userId string userId string
userData *json.RawMessage userData *json.RawMessage
flags uint32 inCall atomic.Uint32
flags atomic.Uint32
options *AddSessionOptions options *AddSessionOptions
} }
@ -69,9 +67,9 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string
sessionId: msg.SessionId, sessionId: msg.SessionId,
userId: msg.UserId, userId: msg.UserId,
userData: msg.User, userData: msg.User,
flags: msg.Flags,
options: msg.Options, options: msg.Options,
} }
result.flags.Store(msg.Flags)
if err := session.events.RegisterSessionListener(publicId, session.Backend(), result); err != nil { if err := session.events.RegisterSessionListener(publicId, session.Backend(), result); err != nil {
return nil, err return nil, err
@ -99,7 +97,7 @@ func (s *VirtualSession) ClientType() string {
} }
func (s *VirtualSession) GetInCall() int { func (s *VirtualSession) GetInCall() int {
return int(atomic.LoadUint32(&s.inCall)) return int(s.inCall.Load())
} }
func (s *VirtualSession) SetInCall(inCall int) bool { func (s *VirtualSession) SetInCall(inCall int) bool {
@ -108,12 +106,12 @@ func (s *VirtualSession) SetInCall(inCall int) bool {
} }
for { for {
old := atomic.LoadUint32(&s.inCall) old := s.inCall.Load()
if old == uint32(inCall) { if old == uint32(inCall) {
return false return false
} }
if atomic.CompareAndSwapUint32(&s.inCall, old, uint32(inCall)) { if s.inCall.CompareAndSwap(old, uint32(inCall)) {
return true return true
} }
} }
@ -144,11 +142,11 @@ func (s *VirtualSession) UserData() *json.RawMessage {
} }
func (s *VirtualSession) SetRoom(room *Room) { func (s *VirtualSession) SetRoom(room *Room) {
atomic.StorePointer(&s.room, unsafe.Pointer(room)) s.room.Store(room)
} }
func (s *VirtualSession) GetRoom() *Room { func (s *VirtualSession) GetRoom() *Room {
return (*Room)(atomic.LoadPointer(&s.room)) return s.room.Load()
} }
func (s *VirtualSession) LeaveRoom(notify bool) *Room { func (s *VirtualSession) LeaveRoom(notify bool) *Room {
@ -243,13 +241,13 @@ func (s *VirtualSession) SessionId() string {
func (s *VirtualSession) AddFlags(flags uint32) bool { func (s *VirtualSession) AddFlags(flags uint32) bool {
for { for {
old := atomic.LoadUint32(&s.flags) old := s.flags.Load()
if old&flags == flags { if old&flags == flags {
// Flags already set. // Flags already set.
return false return false
} }
newFlags := old | flags newFlags := old | flags
if atomic.CompareAndSwapUint32(&s.flags, old, newFlags) { if s.flags.CompareAndSwap(old, newFlags) {
return true return true
} }
// Another thread updated the flags while we were checking, retry. // Another thread updated the flags while we were checking, retry.
@ -258,13 +256,13 @@ func (s *VirtualSession) AddFlags(flags uint32) bool {
func (s *VirtualSession) RemoveFlags(flags uint32) bool { func (s *VirtualSession) RemoveFlags(flags uint32) bool {
for { for {
old := atomic.LoadUint32(&s.flags) old := s.flags.Load()
if old&flags == 0 { if old&flags == 0 {
// Flags not set. // Flags not set.
return false return false
} }
newFlags := old & ^flags newFlags := old & ^flags
if atomic.CompareAndSwapUint32(&s.flags, old, newFlags) { if s.flags.CompareAndSwap(old, newFlags) {
return true return true
} }
// Another thread updated the flags while we were checking, retry. // Another thread updated the flags while we were checking, retry.
@ -273,19 +271,19 @@ func (s *VirtualSession) RemoveFlags(flags uint32) bool {
func (s *VirtualSession) SetFlags(flags uint32) bool { func (s *VirtualSession) SetFlags(flags uint32) bool {
for { for {
old := atomic.LoadUint32(&s.flags) old := s.flags.Load()
if old == flags { if old == flags {
return false return false
} }
if atomic.CompareAndSwapUint32(&s.flags, old, flags) { if s.flags.CompareAndSwap(old, flags) {
return true return true
} }
} }
} }
func (s *VirtualSession) Flags() uint32 { func (s *VirtualSession) Flags() uint32 {
return atomic.LoadUint32(&s.flags) return s.flags.Load()
} }
func (s *VirtualSession) Options() *AddSessionOptions { func (s *VirtualSession) Options() *AddSessionOptions {