diff --git a/client.go b/client.go index 948ad97..dae3a11 100644 --- a/client.go +++ b/client.go @@ -123,12 +123,14 @@ type ClientGeoIpHandler interface { type Client struct { conn *websocket.Conn addr string - handler ClientHandler agent string closed atomic.Int32 country *string logRTT bool + handlerMu sync.RWMutex + handler ClientHandler + session atomic.Pointer[Session] sessionId atomic.Pointer[string] @@ -168,9 +170,17 @@ func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler Cli } func (c *Client) SetHandler(handler ClientHandler) { + c.handlerMu.Lock() + defer c.handlerMu.Unlock() c.handler = handler } +func (c *Client) getHandler() ClientHandler { + c.handlerMu.RLock() + defer c.handlerMu.RUnlock() + return c.handler +} + func (c *Client) IsConnected() bool { return c.closed.Load() == 0 } @@ -225,7 +235,7 @@ func (c *Client) UserAgent() string { func (c *Client) Country() string { if c.country == nil { var country string - if handler, ok := c.handler.(ClientGeoIpHandler); ok { + if handler, ok := c.getHandler().(ClientGeoIpHandler); ok { country = handler.OnLookupCountry(c) } else { country = unknownCountry @@ -264,7 +274,7 @@ func (c *Client) doClose() { c.closer.Close() <-c.messagesDone - c.handler.OnClosed(c) + c.getHandler().OnClosed(c) c.SetSession(nil) } } @@ -335,7 +345,7 @@ func (c *Client) ReadPump() { log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) } } - c.handler.OnRTTReceived(c, rtt) + c.getHandler().OnRTTReceived(c, rtt) } return nil }) @@ -396,7 +406,7 @@ func (c *Client) processMessages() { break } - c.handler.OnMessageReceived(c, buffer.Bytes()) + c.getHandler().OnMessageReceived(c, buffer.Bytes()) bufferPool.Put(buffer) }