Simplify close code of client to make clear when it gets closed internally.

This commit is contained in:
Joachim Bauch 2023-01-19 14:34:51 +01:00
parent b17eb584b4
commit 758899b745
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02

View file

@ -105,10 +105,10 @@ type Client struct {
mu sync.Mutex
closeChan chan bool
messagesDone chan bool
messageChan chan *bytes.Buffer
messageProcessing uint32
closeChan chan struct{}
closeOnce sync.Once
messagesDone chan struct{}
messageChan chan *bytes.Buffer
OnLookupCountry func(*Client) string
OnClosed func(*Client)
@ -127,31 +127,23 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Clien
}
client := &Client{
conn: conn,
addr: remoteAddress,
agent: agent,
logRTT: true,
closeChan: make(chan bool, 1),
messageChan: make(chan *bytes.Buffer, 16),
messagesDone: make(chan bool, 1),
OnLookupCountry: func(client *Client) string { return unknownCountry },
OnClosed: func(client *Client) {},
OnMessageReceived: func(client *Client, data []byte) {},
OnRTTReceived: func(client *Client, rtt time.Duration) {},
}
client.SetConn(conn, remoteAddress)
return client, nil
}
func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string) {
c.conn = conn
c.addr = remoteAddress
c.closeChan = make(chan bool, 1)
c.closeChan = make(chan struct{})
c.messageChan = make(chan *bytes.Buffer, 16)
c.messagesDone = make(chan struct{})
c.OnLookupCountry = func(client *Client) string { return unknownCountry }
c.OnClosed = func(client *Client) {}
c.OnMessageReceived = func(client *Client, data []byte) {}
c.OnRTTReceived = func(c *Client, d time.Duration) {}
}
func (c *Client) IsConnected() bool {
@ -188,38 +180,36 @@ func (c *Client) Country() string {
}
func (c *Client) Close() {
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
if atomic.LoadUint32(&c.closed) >= 2 {
// Prevent reentrant call in case this was the second closing
// step. Would otherwise deadlock in the "Once.Do" call path
// through "Hub.processUnregister" (which calls "Close" again).
return
}
c.mu.Lock()
if c.conn != nil {
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) // nolint
}
c.mu.Unlock()
if atomic.LoadUint32(&c.messageProcessing) == 1 {
// Defer closing
atomic.StoreUint32(&c.closed, 2)
return
}
c.doClose()
c.closeOnce.Do(func() {
c.doClose()
})
}
func (c *Client) doClose() {
c.closeChan <- true
<-c.messagesDone
closed := atomic.AddUint32(&c.closed, 1)
if closed == 1 {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) // nolint
c.conn.Close()
c.conn = nil
}
} else if closed == 2 {
// Both the read pump and message processing must be finished before closing.
close(c.closeChan)
<-c.messagesDone
c.OnClosed(c)
c.SetSession(nil)
c.mu.Lock()
if c.conn != nil {
c.conn.Close()
c.conn = nil
c.OnClosed(c)
c.SetSession(nil)
}
c.mu.Unlock()
}
func (c *Client) SendError(e *Error) bool {
@ -341,7 +331,6 @@ func (c *Client) ReadPump() {
}
func (c *Client) processMessages() {
atomic.StoreUint32(&c.messageProcessing, 1)
for {
buffer := <-c.messageChan
if buffer == nil {
@ -351,12 +340,9 @@ func (c *Client) processMessages() {
c.OnMessageReceived(c, buffer.Bytes())
bufferPool.Put(buffer)
}
atomic.StoreUint32(&c.messageProcessing, 0)
c.messagesDone <- true
if atomic.LoadUint32(&c.closed) == 2 {
c.doClose()
}
close(c.messagesDone)
c.doClose()
}
func (c *Client) writeInternal(message json.Marshaler) bool {