Add mutex for "handler" in client.

Fix flaky race as follow-up to #715
This commit is contained in:
Joachim Bauch 2024-04-23 12:41:48 +02:00
parent d368a060fa
commit f8899ef189
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02

View file

@ -123,12 +123,14 @@ type ClientGeoIpHandler interface {
type Client struct { type Client struct {
conn *websocket.Conn conn *websocket.Conn
addr string addr string
handler ClientHandler
agent string agent string
closed atomic.Int32 closed atomic.Int32
country *string country *string
logRTT bool logRTT bool
handlerMu sync.RWMutex
handler ClientHandler
session atomic.Pointer[Session] session atomic.Pointer[Session]
sessionId atomic.Pointer[string] sessionId atomic.Pointer[string]
@ -168,9 +170,17 @@ func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler Cli
} }
func (c *Client) SetHandler(handler ClientHandler) { func (c *Client) SetHandler(handler ClientHandler) {
c.handlerMu.Lock()
defer c.handlerMu.Unlock()
c.handler = handler c.handler = handler
} }
func (c *Client) getHandler() ClientHandler {
c.handlerMu.RLock()
defer c.handlerMu.RUnlock()
return c.handler
}
func (c *Client) IsConnected() bool { func (c *Client) IsConnected() bool {
return c.closed.Load() == 0 return c.closed.Load() == 0
} }
@ -225,7 +235,7 @@ func (c *Client) UserAgent() string {
func (c *Client) Country() string { func (c *Client) Country() string {
if c.country == nil { if c.country == nil {
var country string var country string
if handler, ok := c.handler.(ClientGeoIpHandler); ok { if handler, ok := c.getHandler().(ClientGeoIpHandler); ok {
country = handler.OnLookupCountry(c) country = handler.OnLookupCountry(c)
} else { } else {
country = unknownCountry country = unknownCountry
@ -264,7 +274,7 @@ func (c *Client) doClose() {
c.closer.Close() c.closer.Close()
<-c.messagesDone <-c.messagesDone
c.handler.OnClosed(c) c.getHandler().OnClosed(c)
c.SetSession(nil) c.SetSession(nil)
} }
} }
@ -335,7 +345,7 @@ func (c *Client) ReadPump() {
log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt)
} }
} }
c.handler.OnRTTReceived(c, rtt) c.getHandler().OnRTTReceived(c, rtt)
} }
return nil return nil
}) })
@ -396,7 +406,7 @@ func (c *Client) processMessages() {
break break
} }
c.handler.OnMessageReceived(c, buffer.Bytes()) c.getHandler().OnMessageReceived(c, buffer.Bytes())
bufferPool.Put(buffer) bufferPool.Put(buffer)
} }