diff --git a/client.go b/client.go index ddc179c..69695e0 100644 --- a/client.go +++ b/client.go @@ -100,9 +100,10 @@ type Client struct { mu sync.Mutex - closeChan chan bool - messagesDone sync.WaitGroup - messageChan chan *bytes.Buffer + closeChan chan bool + messagesDone sync.WaitGroup + messageChan chan *bytes.Buffer + messageProcessing uint32 OnLookupCountry func(*Client) string OnClosed func(*Client) @@ -183,6 +184,16 @@ func (c *Client) Close() { return } + if atomic.LoadUint32(&c.messageProcessing) == 1 { + // Defer closing + atomic.StoreUint32(&c.closed, 2) + return + } + + c.doClose() +} + +func (c *Client) doClose() { c.closeChan <- true c.messagesDone.Wait() @@ -304,7 +315,7 @@ func (c *Client) ReadPump() { } // Stop processing if the client was closed. - if atomic.LoadUint32(&c.closed) == 1 { + if atomic.LoadUint32(&c.closed) != 0 { bufferPool.Put(decodeBuffer) break } @@ -321,10 +332,16 @@ func (c *Client) processMessages() { break } + atomic.StoreUint32(&c.messageProcessing, 1) c.OnMessageReceived(c, buffer.Bytes()) + atomic.StoreUint32(&c.messageProcessing, 0) c.messagesDone.Done() bufferPool.Put(buffer) } + + if atomic.LoadUint32(&c.closed) == 2 { + c.doClose() + } } func (c *Client) writeInternal(message json.Marshaler) bool {