diff --git a/client.go b/client.go index 106da2a..434f138 100644 --- a/client.go +++ b/client.go @@ -105,10 +105,10 @@ type Client struct { mu sync.Mutex - closeChan chan bool - messagesDone chan bool - messageChan chan *bytes.Buffer - messageProcessing uint32 + closeChan chan struct{} + closeOnce sync.Once + messagesDone chan struct{} + messageChan chan *bytes.Buffer OnLookupCountry func(*Client) string OnClosed func(*Client) @@ -127,31 +127,23 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Clien } client := &Client{ - conn: conn, - addr: remoteAddress, agent: agent, logRTT: true, - - closeChan: make(chan bool, 1), - messageChan: make(chan *bytes.Buffer, 16), - messagesDone: make(chan bool, 1), - - OnLookupCountry: func(client *Client) string { return unknownCountry }, - OnClosed: func(client *Client) {}, - OnMessageReceived: func(client *Client, data []byte) {}, - OnRTTReceived: func(client *Client, rtt time.Duration) {}, } + client.SetConn(conn, remoteAddress) return client, nil } func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string) { c.conn = conn c.addr = remoteAddress - c.closeChan = make(chan bool, 1) + c.closeChan = make(chan struct{}) c.messageChan = make(chan *bytes.Buffer, 16) + c.messagesDone = make(chan struct{}) c.OnLookupCountry = func(client *Client) string { return unknownCountry } c.OnClosed = func(client *Client) {} c.OnMessageReceived = func(client *Client, data []byte) {} + c.OnRTTReceived = func(c *Client, d time.Duration) {} } func (c *Client) IsConnected() bool { @@ -188,38 +180,36 @@ func (c *Client) Country() string { } func (c *Client) Close() { - if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + if atomic.LoadUint32(&c.closed) >= 2 { + // Prevent reentrant call in case this was the second closing + // step. Would otherwise deadlock in the "Once.Do" call path + // through "Hub.processUnregister" (which calls "Close" again). return } - c.mu.Lock() - if c.conn != nil { - c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) // nolint - } - c.mu.Unlock() - - if atomic.LoadUint32(&c.messageProcessing) == 1 { - // Defer closing - atomic.StoreUint32(&c.closed, 2) - return - } - - c.doClose() + c.closeOnce.Do(func() { + c.doClose() + }) } func (c *Client) doClose() { - c.closeChan <- true - <-c.messagesDone + closed := atomic.AddUint32(&c.closed, 1) + if closed == 1 { + c.mu.Lock() + defer c.mu.Unlock() + if c.conn != nil { + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) // nolint + c.conn.Close() + c.conn = nil + } + } else if closed == 2 { + // Both the read pump and message processing must be finished before closing. + close(c.closeChan) + <-c.messagesDone - c.OnClosed(c) - c.SetSession(nil) - - c.mu.Lock() - if c.conn != nil { - c.conn.Close() - c.conn = nil + c.OnClosed(c) + c.SetSession(nil) } - c.mu.Unlock() } func (c *Client) SendError(e *Error) bool { @@ -341,7 +331,6 @@ func (c *Client) ReadPump() { } func (c *Client) processMessages() { - atomic.StoreUint32(&c.messageProcessing, 1) for { buffer := <-c.messageChan if buffer == nil { @@ -351,12 +340,9 @@ func (c *Client) processMessages() { c.OnMessageReceived(c, buffer.Bytes()) bufferPool.Put(buffer) } - atomic.StoreUint32(&c.messageProcessing, 0) - c.messagesDone <- true - if atomic.LoadUint32(&c.closed) == 2 { - c.doClose() - } + close(c.messagesDone) + c.doClose() } func (c *Client) writeInternal(message json.Marshaler) bool {