Fix race condition where closeOnLeave could be set after leave was received.

With that also add checklocks annotations to federation client.
This commit is contained in:
Joachim Bauch 2025-11-04 15:46:19 +01:00
commit 3fd89f7113
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
4 changed files with 76 additions and 26 deletions

View file

@ -87,20 +87,26 @@ type FederationClient struct {
changeRoomId atomic.Bool
federation atomic.Pointer[RoomFederationMessage]
mu sync.Mutex
dialer *websocket.Dialer
url string
conn *websocket.Conn
closer *Closer
mu sync.Mutex
dialer *websocket.Dialer
url string
// +checklocks:mu
conn *websocket.Conn
closer *Closer
// +checklocks:mu
reconnectDelay time.Duration
reconnecting bool
reconnectFunc *time.Timer
reconnecting atomic.Bool
// +checklocks:mu
reconnectFunc *time.Timer
helloMu sync.Mutex
helloMu sync.Mutex
// +checklocks:helloMu
helloMsgId string
helloAuth *FederationAuthParams
resumeId PrivateSessionId
hello atomic.Pointer[HelloServerMessage]
// +checklocks:helloMu
helloAuth *FederationAuthParams
// +checklocks:helloMu
resumeId PrivateSessionId
hello atomic.Pointer[HelloServerMessage]
// +checklocks:helloMu
pendingMessages []*ClientMessage
@ -167,6 +173,17 @@ func NewFederationClient(ctx context.Context, hub *Hub, session *ClientSession,
return result, nil
}
func (c *FederationClient) LocalAddr() net.Addr {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
return nil
}
return c.conn.LocalAddr()
}
func (c *FederationClient) URL() string {
return c.federation.Load().parsedSignalingUrl.String()
}
@ -255,16 +272,20 @@ func (c *FederationClient) Leave(message *ClientMessage) error {
}
}
if err := c.sendMessageLocked(message); err != nil && !errors.Is(err, websocket.ErrCloseSent) {
return err
c.closeOnLeave.Store(true)
if err := c.sendMessageLocked(message); err != nil {
c.closeOnLeave.Store(false)
if !errors.Is(err, websocket.ErrCloseSent) {
return err
}
}
c.closeOnLeave.Store(true)
return nil
}
func (c *FederationClient) Close() {
c.closer.Close()
c.hub.removeFederationClient(c)
c.mu.Lock()
defer c.mu.Unlock()
@ -272,6 +293,7 @@ func (c *FederationClient) Close() {
c.closeConnection(true)
}
// +checklocks:c.mu
func (c *FederationClient) closeConnection(withBye bool) {
if c.conn == nil {
return
@ -311,8 +333,9 @@ func (c *FederationClient) scheduleReconnect() {
c.scheduleReconnectLocked()
}
// +checklocks:c.mu
func (c *FederationClient) scheduleReconnectLocked() {
c.reconnecting = true
c.reconnecting.Store(true)
if c.hello.Swap(nil) != nil {
c.session.SendMessage(&ServerMessage{
Type: "event",
@ -454,6 +477,7 @@ func (c *FederationClient) sendHello(auth *FederationAuthParams) error {
return c.sendHelloLocked(auth)
}
// +checklocks:c.helloMu
func (c *FederationClient) sendHelloLocked(auth *FederationAuthParams) error {
c.helloMsgId = newRandomString(8)
@ -539,7 +563,7 @@ func (c *FederationClient) processHello(msg *ServerMessage) {
c.hello.Store(msg.Hello)
if c.resumeId == "" {
c.resumeId = msg.Hello.ResumeId
if c.reconnecting {
if c.reconnecting.Load() {
c.session.SendMessage(&ServerMessage{
Type: "event",
Event: &EventServerMessage{
@ -941,6 +965,7 @@ func (c *FederationClient) deferMessage(message *ClientMessage) {
}
}
// +checklocks:c.mu
func (c *FederationClient) sendMessageLocked(message *ClientMessage) error {
if c.conn == nil {
if message.Type != "room" {

View file

@ -250,12 +250,14 @@ func Test_Federation(t *testing.T) {
// Client1 will receive the updated "remoteSessionId"
if message, ok := client1.RunUntilMessage(ctx); ok {
client1.checkSingleMessageJoined(message)
evt := message.Event.Join[0]
remoteSessionId = evt.SessionId
assert.NotEqual(hello2.Hello.SessionId, remoteSessionId)
assert.Equal(testDefaultUserId+"2", evt.UserId)
assert.True(evt.Federated)
assert.Equal(features2, evt.Features)
if assert.Len(message.Event.Join, 1, "invalid message received: %+v", message) {
evt := message.Event.Join[0]
remoteSessionId = evt.SessionId
assert.NotEqual(hello2.Hello.SessionId, remoteSessionId)
assert.Equal(testDefaultUserId+"2", evt.UserId)
assert.True(evt.Federated)
assert.Equal(features2, evt.Features)
}
}
client2.RunUntilJoined(ctx, hello1.Hello, hello2.Hello)
@ -659,7 +661,7 @@ func Test_FederationChangeRoom(t *testing.T) {
session2 := hub2.GetSessionByPublicId(hello2.Hello.SessionId).(*ClientSession)
fed := session2.GetFederationClient()
require.NotNil(fed)
localAddr := fed.conn.LocalAddr()
localAddr := fed.LocalAddr()
// The client1 will see the remote session id for client2.
var remoteSessionId PublicSessionId
@ -701,7 +703,7 @@ func Test_FederationChangeRoom(t *testing.T) {
fed2 := session2.GetFederationClient()
require.NotNil(fed2)
localAddr2 := fed2.conn.LocalAddr()
localAddr2 := fed2.LocalAddr()
assert.Equal(localAddr, localAddr2)
}

13
hub.go
View file

@ -195,6 +195,8 @@ type Hub struct {
remoteSessions map[*RemoteSession]bool
// +checklocks:mu
federatedSessions map[*ClientSession]bool
// +checklocks:mu
federationClients map[*FederationClient]bool
backendTimeout time.Duration
backend *BackendClient
@ -387,6 +389,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
dialoutSessions: make(map[*ClientSession]bool),
remoteSessions: make(map[*RemoteSession]bool),
federatedSessions: make(map[*ClientSession]bool),
federationClients: make(map[*FederationClient]bool),
backendTimeout: backendTimeout,
backend: backend,
@ -835,6 +838,13 @@ func (h *Hub) removeSession(session Session) (removed bool) {
return
}
func (h *Hub) removeFederationClient(client *FederationClient) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.federationClients, client)
}
// +checklocksread:h.mu
func (h *Hub) hasSessionsLocked(withInternal bool) bool {
if withInternal {
@ -1797,8 +1807,9 @@ func (h *Hub) processRoom(sess Session, message *ClientMessage) {
}
h.mu.Lock()
defer h.mu.Unlock()
h.federatedSessions[session] = true
h.mu.Unlock()
h.federationClients[client] = true
return
}

View file

@ -287,6 +287,7 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
sessions := len(h.sessions)
remoteSessions := len(h.remoteSessions)
federatedSessions := len(h.federatedSessions)
federationClients := len(h.federatedSessions)
h.mu.Unlock()
h.ru.Lock()
rooms := len(h.rooms)
@ -298,6 +299,7 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
sessions == 0 &&
remoteSessions == 0 &&
federatedSessions == 0 &&
federationClients == 0 &&
readActive == 0 &&
writeActive == 0 {
break
@ -308,7 +310,17 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
h.mu.Lock()
h.ru.Lock()
dumpGoroutines("", os.Stderr)
assert.Fail(t, "Error waiting for hub to terminate", "clients %+v / rooms %+v / sessions %+v / remoteSessions %v / federatedSessions %v / %d read / %d write: %s", h.clients, h.rooms, h.sessions, remoteSessions, federatedSessions, readActive, writeActive, ctx.Err())
assert.Fail(t, "Error waiting for hub to terminate", "clients %+v / rooms %+v / sessions %+v / remoteSessions %v / federatedSessions %v / federationClients %v / %d read / %d write: %s",
h.clients,
h.rooms,
h.sessions,
remoteSessions,
federatedSessions,
federationClients,
readActive,
writeActive,
ctx.Err(),
)
h.ru.Unlock()
h.mu.Unlock()
return