diff --git a/async_events_nats.go b/async_events_nats.go index 04699c6..8742707 100644 --- a/async_events_nats.go +++ b/async_events_nats.go @@ -62,7 +62,7 @@ type asyncSubscriberNats struct { client NatsClient receiver chan *nats.Msg - closeChan chan bool + closeChan chan struct{} subscription NatsSubscription processMessage func(*nats.Msg) @@ -80,7 +80,7 @@ func newAsyncSubscriberNats(key string, client NatsClient) (*asyncSubscriberNats client: client, receiver: receiver, - closeChan: make(chan bool), + closeChan: make(chan struct{}), subscription: sub, } return result, nil diff --git a/channel_waiter.go b/channel_waiter.go new file mode 100644 index 0000000..20b0883 --- /dev/null +++ b/channel_waiter.go @@ -0,0 +1,62 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2023 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "sync" +) + +type ChannelWaiters struct { + mu sync.RWMutex + id uint64 + waiters map[uint64]chan struct{} +} + +func (w *ChannelWaiters) Wakeup() { + w.mu.RLock() + defer w.mu.RUnlock() + for _, ch := range w.waiters { + select { + case ch <- struct{}{}: + default: + // Receiver is still processing previous wakeup. + } + } +} + +func (w *ChannelWaiters) Add(ch chan struct{}) uint64 { + w.mu.Lock() + defer w.mu.Unlock() + if w.waiters == nil { + w.waiters = make(map[uint64]chan struct{}) + } + id := w.id + w.id++ + w.waiters[id] = ch + return id +} + +func (w *ChannelWaiters) Remove(id uint64) { + w.mu.Lock() + defer w.mu.Unlock() + delete(w.waiters, id) +} diff --git a/channel_waiter_test.go b/channel_waiter_test.go new file mode 100644 index 0000000..e401ae8 --- /dev/null +++ b/channel_waiter_test.go @@ -0,0 +1,66 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2023 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "testing" +) + +func TestChannelWaiters(t *testing.T) { + var waiters ChannelWaiters + + ch1 := make(chan struct{}, 1) + id1 := waiters.Add(ch1) + defer waiters.Remove(id1) + + ch2 := make(chan struct{}, 1) + id2 := waiters.Add(ch2) + defer waiters.Remove(id2) + + waiters.Wakeup() + <-ch1 + <-ch2 + + select { + case <-ch1: + t.Error("should have not received another event") + case <-ch2: + t.Error("should have not received another event") + default: + } + + ch3 := make(chan struct{}, 1) + id3 := waiters.Add(ch3) + waiters.Remove(id3) + + // Multiple wakeups work even without processing. + waiters.Wakeup() + waiters.Wakeup() + waiters.Wakeup() + <-ch1 + <-ch2 + select { + case <-ch3: + t.Error("should have not received another event") + default: + } +} diff --git a/client.go b/client.go index 106da2a..7c37232 100644 --- a/client.go +++ b/client.go @@ -93,9 +93,20 @@ type WritableClientMessage interface { CloseAfterSend(session Session) bool } +type ClientHandler interface { + OnClosed(*Client) + OnMessageReceived(*Client, []byte) + OnRTTReceived(*Client, time.Duration) +} + +type ClientGeoIpHandler interface { + OnLookupCountry(*Client) string +} + type Client struct { conn *websocket.Conn addr string + handler ClientHandler agent string closed uint32 country *string @@ -105,18 +116,13 @@ type Client struct { mu sync.Mutex - closeChan chan bool - messagesDone chan bool - messageChan chan *bytes.Buffer - messageProcessing uint32 - - OnLookupCountry func(*Client) string - OnClosed func(*Client) - OnMessageReceived func(*Client, []byte) - OnRTTReceived func(*Client, time.Duration) + closer *Closer + closeOnce sync.Once + messagesDone chan struct{} + messageChan chan *bytes.Buffer } -func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) { +func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler ClientHandler) (*Client, error) { remoteAddress = strings.TrimSpace(remoteAddress) if remoteAddress == "" { remoteAddress = "unknown remote address" @@ -127,31 +133,20 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Clien } client := &Client{ - conn: conn, - addr: remoteAddress, agent: agent, logRTT: true, - - closeChan: make(chan bool, 1), - messageChan: make(chan *bytes.Buffer, 16), - messagesDone: make(chan bool, 1), - - OnLookupCountry: func(client *Client) string { return unknownCountry }, - OnClosed: func(client *Client) {}, - OnMessageReceived: func(client *Client, data []byte) {}, - OnRTTReceived: func(client *Client, rtt time.Duration) {}, } + client.SetConn(conn, remoteAddress, handler) return client, nil } -func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string) { +func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler ClientHandler) { c.conn = conn c.addr = remoteAddress - c.closeChan = make(chan bool, 1) + c.handler = handler + c.closer = NewCloser() c.messageChan = make(chan *bytes.Buffer, 16) - c.OnLookupCountry = func(client *Client) string { return unknownCountry } - c.OnClosed = func(client *Client) {} - c.OnMessageReceived = func(client *Client, data []byte) {} + c.messagesDone = make(chan struct{}) } func (c *Client) IsConnected() bool { @@ -180,7 +175,12 @@ func (c *Client) UserAgent() string { func (c *Client) Country() string { if c.country == nil { - country := c.OnLookupCountry(c) + var country string + if handler, ok := c.handler.(ClientGeoIpHandler); ok { + country = handler.OnLookupCountry(c) + } else { + country = unknownCountry + } c.country = &country } @@ -188,38 +188,36 @@ func (c *Client) Country() string { } func (c *Client) Close() { - if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + if atomic.LoadUint32(&c.closed) >= 2 { + // Prevent reentrant call in case this was the second closing + // step. Would otherwise deadlock in the "Once.Do" call path + // through "Hub.processUnregister" (which calls "Close" again). return } - c.mu.Lock() - if c.conn != nil { - c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) // nolint - } - c.mu.Unlock() - - if atomic.LoadUint32(&c.messageProcessing) == 1 { - // Defer closing - atomic.StoreUint32(&c.closed, 2) - return - } - - c.doClose() + c.closeOnce.Do(func() { + c.doClose() + }) } func (c *Client) doClose() { - c.closeChan <- true - <-c.messagesDone + closed := atomic.AddUint32(&c.closed, 1) + if closed == 1 { + c.mu.Lock() + defer c.mu.Unlock() + if c.conn != nil { + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) // nolint + c.conn.Close() + c.conn = nil + } + } else if closed == 2 { + // Both the read pump and message processing must be finished before closing. + c.closer.Close() + <-c.messagesDone - c.OnClosed(c) - c.SetSession(nil) - - c.mu.Lock() - if c.conn != nil { - c.conn.Close() - c.conn = nil + c.handler.OnClosed(c) + c.SetSession(nil) } - c.mu.Unlock() } func (c *Client) SendError(e *Error) bool { @@ -258,6 +256,8 @@ func (c *Client) ReadPump() { c.Close() }() + go c.processMessages() + addr := c.RemoteAddr() c.mu.Lock() conn := c.conn @@ -284,13 +284,11 @@ func (c *Client) ReadPump() { log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) } } - c.OnRTTReceived(c, rtt) + c.handler.OnRTTReceived(c, rtt) } return nil }) - go c.processMessages() - for { conn.SetReadDeadline(time.Now().Add(pongWait)) // nolint messageType, reader, err := conn.NextReader() @@ -341,22 +339,18 @@ func (c *Client) ReadPump() { } func (c *Client) processMessages() { - atomic.StoreUint32(&c.messageProcessing, 1) for { buffer := <-c.messageChan if buffer == nil { break } - c.OnMessageReceived(c, buffer.Bytes()) + c.handler.OnMessageReceived(c, buffer.Bytes()) bufferPool.Put(buffer) } - atomic.StoreUint32(&c.messageProcessing, 0) - c.messagesDone <- true - if atomic.LoadUint32(&c.closed) == 2 { - c.doClose() - } + close(c.messagesDone) + c.doClose() } func (c *Client) writeInternal(message json.Marshaler) bool { @@ -494,7 +488,7 @@ func (c *Client) WritePump() { if !c.sendPing() { return } - case <-c.closeChan: + case <-c.closer.C: return } } diff --git a/clientsession.go b/clientsession.go index d8b3a45..a10e262 100644 --- a/clientsession.go +++ b/clientsession.go @@ -76,8 +76,7 @@ type ClientSession struct { room unsafe.Pointer roomSessionId string - publisherWaitersId uint64 - publisherWaiters map[uint64]chan bool + publisherWaiters ChannelWaiters publishers map[string]McuPublisher subscribers map[string]McuSubscriber @@ -832,26 +831,6 @@ func (s *ClientSession) checkOfferTypeLocked(streamType string, data *MessageCli return 0, nil } -func (s *ClientSession) wakeupPublisherWaiters() { - for _, ch := range s.publisherWaiters { - ch <- true - } -} - -func (s *ClientSession) addPublisherWaiter(ch chan bool) uint64 { - if s.publisherWaiters == nil { - s.publisherWaiters = make(map[uint64]chan bool) - } - id := s.publisherWaitersId + 1 - s.publisherWaitersId = id - s.publisherWaiters[id] = ch - return id -} - -func (s *ClientSession) removePublisherWaiter(id uint64) { - delete(s.publisherWaiters, id) -} - func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, streamType string, data *MessageClientMessageData) (McuPublisher, error) { s.mu.Lock() defer s.mu.Unlock() @@ -900,7 +879,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea s.publishers[streamType] = publisher } log.Printf("Publishing %s as %s for session %s", streamType, publisher.Id(), s.PublicId()) - s.wakeupPublisherWaiters() + s.publisherWaiters.Wakeup() } else { publisher.SetMedia(mediaTypes) } @@ -928,9 +907,9 @@ func (s *ClientSession) GetOrWaitForPublisher(ctx context.Context, streamType st return publisher } - ch := make(chan bool, 1) - id := s.addPublisherWaiter(ch) - defer s.removePublisherWaiter(id) + ch := make(chan struct{}, 1) + id := s.publisherWaiters.Add(ch) + defer s.publisherWaiters.Remove(id) for { s.mu.Unlock() diff --git a/closer.go b/closer.go new file mode 100644 index 0000000..a68c850 --- /dev/null +++ b/closer.go @@ -0,0 +1,47 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2023 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "sync/atomic" +) + +type Closer struct { + closed uint32 + C chan struct{} +} + +func NewCloser() *Closer { + return &Closer{ + C: make(chan struct{}), + } +} + +func (c *Closer) IsClosed() bool { + return atomic.LoadUint32(&c.closed) != 0 +} + +func (c *Closer) Close() { + if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + close(c.C) + } +} diff --git a/closer_test.go b/closer_test.go new file mode 100644 index 0000000..1a14ea1 --- /dev/null +++ b/closer_test.go @@ -0,0 +1,62 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2023 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "sync" + "testing" +) + +func TestCloserMulti(t *testing.T) { + closer := NewCloser() + + var wg sync.WaitGroup + count := 10 + for i := 0; i < count; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-closer.C + }() + } + + if closer.IsClosed() { + t.Error("should not be closed") + } + closer.Close() + if !closer.IsClosed() { + t.Error("should be closed") + } + wg.Wait() +} + +func TestCloserCloseBeforeWait(t *testing.T) { + closer := NewCloser() + closer.Close() + if !closer.IsClosed() { + t.Error("should be closed") + } + <-closer.C + if !closer.IsClosed() { + t.Error("should be closed") + } +} diff --git a/deferred_executor.go b/deferred_executor.go index ee00b57..a6f46c7 100644 --- a/deferred_executor.go +++ b/deferred_executor.go @@ -33,8 +33,7 @@ import ( // their order. type DeferredExecutor struct { queue chan func() - closeChan chan bool - closed chan bool + closed chan struct{} closeOnce sync.Once } @@ -43,28 +42,24 @@ func NewDeferredExecutor(queueSize int) *DeferredExecutor { queueSize = 0 } result := &DeferredExecutor{ - queue: make(chan func(), queueSize), - closeChan: make(chan bool, 1), - closed: make(chan bool, 1), + queue: make(chan func(), queueSize), + closed: make(chan struct{}), } go result.run() return result } func (e *DeferredExecutor) run() { -loop: + defer close(e.closed) + for { - select { - case f := <-e.queue: - if f == nil { - break loop - } - f() - case <-e.closeChan: - break loop + f := <-e.queue + if f == nil { + break } + + f() } - e.closed <- true } func getFunctionName(i interface{}) string { @@ -83,14 +78,9 @@ func (e *DeferredExecutor) Execute(f func()) { } func (e *DeferredExecutor) Close() { - select { - case e.closeChan <- true: - e.closeOnce.Do(func() { - close(e.queue) - }) - default: - // Already closed. - } + e.closeOnce.Do(func() { + close(e.queue) + }) } func (e *DeferredExecutor) waitForStop() { diff --git a/deferred_executor_test.go b/deferred_executor_test.go index 5aa8c08..6e1b12c 100644 --- a/deferred_executor_test.go +++ b/deferred_executor_test.go @@ -109,3 +109,12 @@ func TestDeferredExecutor_DeferAfterClose(t *testing.T) { t.Error("method should not have been called") }) } + +func TestDeferredExecutor_WaitForStopTwice(t *testing.T) { + e := NewDeferredExecutor(64) + defer e.waitForStop() + + e.Close() + + e.waitForStop() +} diff --git a/hub.go b/hub.go index 5afb393..1832d20 100644 --- a/hub.go +++ b/hub.go @@ -119,8 +119,7 @@ type Hub struct { infoInternal *WelcomeServerMessage welcome atomic.Value // *ServerMessage - stopped int32 - stopChan chan bool + closer *Closer readPumpActive uint32 writePumpActive uint32 @@ -314,7 +313,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer info: NewWelcomeServerMessage(version, DefaultFeatures...), infoInternal: NewWelcomeServerMessage(version, DefaultFeaturesInternal...), - stopChan: make(chan bool), + closer: NewCloser(), roomUpdated: make(chan *BackendServerRoomRequest), roomDeleted: make(chan *BackendServerRoomRequest), @@ -417,7 +416,7 @@ func (h *Hub) updateGeoDatabase() { defer atomic.CompareAndSwapInt32(&h.geoipUpdating, 1, 0) delay := time.Second - for atomic.LoadInt32(&h.stopped) == 0 { + for !h.closer.IsClosed() { err := h.geoip.Update() if err == nil { break @@ -458,7 +457,7 @@ loop: h.performHousekeeping(now) case <-geoipUpdater.C: go h.updateGeoDatabase() - case <-h.stopChan: + case <-h.closer.C: break loop } } @@ -468,11 +467,7 @@ loop: } func (h *Hub) Stop() { - atomic.StoreInt32(&h.stopped, 1) - select { - case h.stopChan <- true: - default: - } + h.closer.Close() } func (h *Hub) Reload(config *goconf.ConfigFile) { @@ -2320,20 +2315,12 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { return } - client, err := NewClient(conn, addr, agent) + client, err := NewClient(conn, addr, agent, h) if err != nil { log.Printf("Could not create client for %s: %s", addr, err) return } - if h.geoip != nil { - client.OnLookupCountry = h.lookupClientCountry - } - client.OnMessageReceived = h.processMessage - client.OnClosed = func(client *Client) { - h.processUnregister(client) - } - h.processNewClient(client) go func(h *Hub) { atomic.AddUint32(&h.writePumpActive, 1) @@ -2346,3 +2333,23 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { client.ReadPump() }(h) } + +func (h *Hub) OnLookupCountry(client *Client) string { + if h.geoip == nil { + return unknownCountry + } + + return h.lookupClientCountry(client) +} + +func (h *Hub) OnClosed(client *Client) { + h.processUnregister(client) +} + +func (h *Hub) OnMessageReceived(client *Client, data []byte) { + h.processMessage(client, data) +} + +func (h *Hub) OnRTTReceived(client *Client, rtt time.Duration) { + // Ignore +} diff --git a/janus_client.go b/janus_client.go index e4cfeb4..5dc1991 100644 --- a/janus_client.go +++ b/janus_client.go @@ -172,7 +172,7 @@ func unexpected(request string) error { type transaction struct { ch chan interface{} incoming chan interface{} - quitChan chan bool + closer *Closer } func (t *transaction) run() { @@ -180,7 +180,7 @@ func (t *transaction) run() { select { case msg := <-t.incoming: t.ch <- msg - case <-t.quitChan: + case <-t.closer.C: return } } @@ -191,18 +191,14 @@ func (t *transaction) add(msg interface{}) { } func (t *transaction) quit() { - select { - case t.quitChan <- true: - default: - // Already scheduled to quit. - } + t.closer.Close() } func newTransaction() *transaction { t := &transaction{ ch: make(chan interface{}, 1), incoming: make(chan interface{}, 8), - quitChan: make(chan bool, 1), + closer: NewCloser(), } return t } @@ -239,7 +235,7 @@ type JanusGateway struct { conn *websocket.Conn transactions map[uint64]*transaction - closeChan chan bool + closer *Closer writeMu sync.Mutex } @@ -269,15 +265,16 @@ func NewJanusGateway(wsURL string, listener GatewayListener) (*JanusGateway, err return nil, err } - gateway := new(JanusGateway) - gateway.conn = conn - gateway.transactions = make(map[uint64]*transaction) - gateway.Sessions = make(map[uint64]*JanusSession) - gateway.closeChan = make(chan bool) if listener == nil { listener = new(dummyGatewayListener) } - gateway.listener = listener + gateway := &JanusGateway{ + conn: conn, + listener: listener, + transactions: make(map[uint64]*transaction), + Sessions: make(map[uint64]*JanusSession), + closer: NewCloser(), + } go gateway.ping() go gateway.recv() @@ -286,7 +283,7 @@ func NewJanusGateway(wsURL string, listener GatewayListener) (*JanusGateway, err // Close closes the underlying connection to the Gateway. func (gateway *JanusGateway) Close() error { - gateway.closeChan <- true + gateway.closer.Close() gateway.writeMu.Lock() if gateway.conn == nil { gateway.writeMu.Unlock() @@ -382,7 +379,7 @@ loop: if err != nil { log.Println("Error sending ping to MCU:", err) } - case <-gateway.closeChan: + case <-gateway.closer.C: break loop } } diff --git a/mcu_proxy.go b/mcu_proxy.go index 9c1047d..3e1e9cd 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -305,8 +305,8 @@ type mcuProxyConnection struct { ip net.IP mu sync.Mutex - closeChan chan bool - closedChan chan bool + closer *Closer + closedDone *Closer closed uint32 conn *websocket.Conn @@ -344,8 +344,8 @@ func newMcuProxyConnection(proxy *mcuProxy, baseUrl string, ip net.IP) (*mcuProx rawUrl: baseUrl, url: parsed, ip: ip, - closeChan: make(chan bool, 1), - closedChan: make(chan bool, 1), + closer: NewCloser(), + closedDone: NewCloser(), reconnectInterval: int64(initialReconnectInterval), load: loadNotConnected, callbacks: make(map[string]func(*ProxyServerMessage)), @@ -433,7 +433,7 @@ func (c *mcuProxyConnection) readPump() { if atomic.LoadUint32(&c.closed) == 0 { c.scheduleReconnect() } else { - c.closedChan <- true + c.closedDone.Close() } }() defer c.close() @@ -515,7 +515,7 @@ func (c *mcuProxyConnection) writePump() { c.reconnect() case <-ticker.C: c.sendPing() - case <-c.closeChan: + case <-c.closer.C: return } } @@ -543,7 +543,7 @@ func (c *mcuProxyConnection) stop(ctx context.Context) { return } - c.closeChan <- true + c.closer.Close() if err := c.sendClose(); err != nil { if err != ErrNotConnected { log.Printf("Could not send close message to %s: %s", c, err) @@ -553,7 +553,7 @@ func (c *mcuProxyConnection) stop(ctx context.Context) { } select { - case <-c.closedChan: + case <-c.closedDone.C: case <-ctx.Done(): if err := ctx.Err(); err != nil { log.Printf("Error waiting for connection to %s get closed: %s", c, err) @@ -1124,8 +1124,7 @@ type mcuProxy struct { mu sync.RWMutex publishers map[string]*mcuProxyConnection - publisherWaitersId uint64 - publisherWaiters map[uint64]chan bool + publisherWaiters ChannelWaiters continentsMap atomic.Value @@ -1193,8 +1192,6 @@ func NewMcuProxy(config *goconf.ConfigFile, etcdClient *EtcdClient, rpcClients * publishers: make(map[string]*mcuProxyConnection), - publisherWaiters: make(map[uint64]chan bool), - rpcClients: rpcClients, } @@ -1861,25 +1858,6 @@ func (m *mcuProxy) removePublisher(publisher *mcuProxyPublisher) { delete(m.publishers, publisher.id+"|"+publisher.StreamType()) } -func (m *mcuProxy) wakeupWaiters() { - m.mu.RLock() - defer m.mu.RUnlock() - for _, ch := range m.publisherWaiters { - ch <- true - } -} - -func (m *mcuProxy) addWaiter(ch chan bool) uint64 { - id := m.publisherWaitersId + 1 - m.publisherWaitersId = id - m.publisherWaiters[id] = ch - return id -} - -func (m *mcuProxy) removeWaiter(id uint64) { - delete(m.publisherWaiters, id) -} - func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id string, sid string, streamType string, bitrate int, mediaTypes MediaType, initiator McuInitiator) (McuPublisher, error) { connections := m.getSortedConnections(initiator) for _, conn := range connections { @@ -1910,7 +1888,7 @@ func (m *mcuProxy) NewPublisher(ctx context.Context, listener McuListener, id st m.mu.Lock() m.publishers[id+"|"+streamType] = conn m.mu.Unlock() - m.wakeupWaiters() + m.publisherWaiters.Wakeup() return publisher, nil } @@ -1935,9 +1913,9 @@ func (m *mcuProxy) waitForPublisherConnection(ctx context.Context, publisher str return conn } - ch := make(chan bool, 1) - id := m.addWaiter(ch) - defer m.removeWaiter(id) + ch := make(chan struct{}, 1) + id := m.publisherWaiters.Add(ch) + defer m.publisherWaiters.Remove(id) statsWaitingForPublisherTotal.WithLabelValues(streamType).Inc() for { diff --git a/natsclient_loopback.go b/natsclient_loopback.go index 8c95991..56b6fb6 100644 --- a/natsclient_loopback.go +++ b/natsclient_loopback.go @@ -35,7 +35,6 @@ type LoopbackNatsClient struct { mu sync.Mutex subscriptions map[string]map[*loopbackNatsSubscription]bool - stopping bool wakeup sync.Cond incoming list.List } @@ -53,10 +52,11 @@ func (c *LoopbackNatsClient) processMessages() { c.mu.Lock() defer c.mu.Unlock() for { - for !c.stopping && c.incoming.Len() == 0 { + for c.subscriptions != nil && c.incoming.Len() == 0 { c.wakeup.Wait() } - if c.stopping { + if c.subscriptions == nil { + // Client was closed. break } @@ -91,7 +91,6 @@ func (c *LoopbackNatsClient) Close() { defer c.mu.Unlock() c.subscriptions = nil - c.stopping = true c.incoming.Init() c.wakeup.Signal() } diff --git a/notifier.go b/notifier.go index 94af5bd..3466f45 100644 --- a/notifier.go +++ b/notifier.go @@ -29,7 +29,11 @@ import ( type Waiter struct { key string - SingleWaiter + sw *SingleWaiter +} + +func (w *Waiter) Wait(ctx context.Context) error { + return w.sw.Wait(ctx) } type Notifier struct { @@ -47,22 +51,15 @@ func (n *Notifier) NewWaiter(key string) *Waiter { if found { w := &Waiter{ key: key, - SingleWaiter: SingleWaiter{ - ctx: waiter.ctx, - cancel: waiter.cancel, - }, + sw: waiter.sw, } n.waiterMap[key][w] = true return w } - ctx, cancel := context.WithCancel(context.Background()) waiter = &Waiter{ key: key, - SingleWaiter: SingleWaiter{ - ctx: ctx, - cancel: cancel, - }, + sw: newSingleWaiter(), } if n.waiters == nil { n.waiters = make(map[string]*Waiter) @@ -83,7 +80,7 @@ func (n *Notifier) Reset() { defer n.Unlock() for _, w := range n.waiters { - w.cancel() + w.sw.cancel() } n.waiters = nil n.waiterMap = nil @@ -98,7 +95,7 @@ func (n *Notifier) Release(w *Waiter) { delete(waiters, w) if len(waiters) == 0 { delete(n.waiters, w.key) - w.cancel() + w.sw.cancel() } } } @@ -109,7 +106,7 @@ func (n *Notifier) Notify(key string) { defer n.Unlock() if w, found := n.waiters[key]; found { - w.cancel() + w.sw.cancel() delete(n.waiters, w.key) delete(n.waiterMap, w.key) } diff --git a/proxy/proxy_client.go b/proxy/proxy_client.go index 10ccf7d..c9c495a 100644 --- a/proxy/proxy_client.go +++ b/proxy/proxy_client.go @@ -23,11 +23,11 @@ package main import ( "sync/atomic" + "time" "unsafe" "github.com/gorilla/websocket" - - "github.com/strukturag/nextcloud-spreed-signaling" + signaling "github.com/strukturag/nextcloud-spreed-signaling" ) type ProxyClient struct { @@ -42,7 +42,7 @@ func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*Pro client := &ProxyClient{ proxy: proxy, } - client.SetConn(conn, addr) + client.SetConn(conn, addr, client) return client, nil } @@ -53,3 +53,20 @@ func (c *ProxyClient) GetSession() *ProxySession { func (c *ProxyClient) SetSession(session *ProxySession) { atomic.StorePointer(&c.session, unsafe.Pointer(session)) } + +func (c *ProxyClient) OnClosed(client *signaling.Client) { + if session := c.GetSession(); session != nil { + session.MarkUsed() + } + c.proxy.clientClosed(&c.Client) +} + +func (c *ProxyClient) OnMessageReceived(client *signaling.Client, data []byte) { + c.proxy.processMessage(c, data) +} + +func (c *ProxyClient) OnRTTReceived(client *signaling.Client, rtt time.Duration) { + if session := c.GetSession(); session != nil { + session.MarkUsed() + } +} diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index d1e2938..ac837c3 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -433,21 +433,6 @@ func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { return } - client.OnClosed = func(c *signaling.Client) { - if session := client.GetSession(); session != nil { - session.MarkUsed() - } - s.clientClosed(c) - } - client.OnMessageReceived = func(c *signaling.Client, data []byte) { - s.processMessage(client, data) - } - client.OnRTTReceived = func(c *signaling.Client, rtt time.Duration) { - if session := client.GetSession(); session != nil { - session.MarkUsed() - } - } - go client.WritePump() go client.ReadPump() } diff --git a/room.go b/room.go index a3f8e20..a4d0e09 100644 --- a/room.go +++ b/room.go @@ -67,9 +67,9 @@ type Room struct { properties *json.RawMessage - closeChan chan bool - mu *sync.RWMutex - sessions map[string]Session + closer *Closer + mu *sync.RWMutex + sessions map[string]Session internalSessions map[Session]bool virtualSessions map[*VirtualSession]bool @@ -104,9 +104,9 @@ func NewRoom(roomId string, properties *json.RawMessage, hub *Hub, events AsyncE properties: properties, - closeChan: make(chan bool, 1), - mu: &sync.RWMutex{}, - sessions: make(map[string]Session), + closer: NewCloser(), + mu: &sync.RWMutex{}, + sessions: make(map[string]Session), internalSessions: make(map[Session]bool), virtualSessions: make(map[*VirtualSession]bool), @@ -173,7 +173,7 @@ func (r *Room) run() { loop: for { select { - case <-r.closeChan: + case <-r.closer.C: break loop case <-ticker.C: r.publishActiveSessions() @@ -182,10 +182,7 @@ loop: } func (r *Room) doClose() { - select { - case r.closeChan <- true: - default: - } + r.closer.Close() } func (r *Room) unsubscribeBackend() { diff --git a/room_ping.go b/room_ping.go index 48c301a..c51cb91 100644 --- a/room_ping.go +++ b/room_ping.go @@ -63,8 +63,8 @@ func (e *pingEntries) RemoveRoom(room *Room) { // For that, all ping requests across rooms of enabled instances are combined // and sent out batched every "updateActiveSessionsInterval" seconds. type RoomPing struct { - mu sync.Mutex - closeChan chan bool + mu sync.Mutex + closer *Closer backend *BackendClient capabilities *Capabilities @@ -74,7 +74,7 @@ type RoomPing struct { func NewRoomPing(backend *BackendClient, capabilities *Capabilities) (*RoomPing, error) { result := &RoomPing{ - closeChan: make(chan bool, 1), + closer: NewCloser(), backend: backend, capabilities: capabilities, } @@ -87,10 +87,7 @@ func (p *RoomPing) Start() { } func (p *RoomPing) Stop() { - select { - case p.closeChan <- true: - default: - } + p.closer.Close() } func (p *RoomPing) run() { @@ -98,7 +95,7 @@ func (p *RoomPing) run() { loop: for { select { - case <-p.closeChan: + case <-p.closer.C: break loop case <-ticker.C: p.publishActiveSessions() diff --git a/single_notifier.go b/single_notifier.go index 91c4b6f..921542a 100644 --- a/single_notifier.go +++ b/single_notifier.go @@ -27,19 +27,43 @@ import ( ) type SingleWaiter struct { - ctx context.Context - cancel context.CancelFunc + root bool + ch chan struct{} + once sync.Once +} + +func newSingleWaiter() *SingleWaiter { + return &SingleWaiter{ + root: true, + ch: make(chan struct{}), + } +} + +func (w *SingleWaiter) subWaiter() *SingleWaiter { + return &SingleWaiter{ + ch: w.ch, + } } func (w *SingleWaiter) Wait(ctx context.Context) error { select { - case <-w.ctx.Done(): + case <-w.ch: return nil case <-ctx.Done(): return ctx.Err() } } +func (w *SingleWaiter) cancel() { + if !w.root { + return + } + + w.once.Do(func() { + close(w.ch) + }) +} + type SingleNotifier struct { sync.Mutex @@ -52,21 +76,14 @@ func (n *SingleNotifier) NewWaiter() *SingleWaiter { defer n.Unlock() if n.waiter == nil { - ctx, cancel := context.WithCancel(context.Background()) - n.waiter = &SingleWaiter{ - ctx: ctx, - cancel: cancel, - } + n.waiter = newSingleWaiter() } if n.waiters == nil { n.waiters = make(map[*SingleWaiter]bool) } - w := &SingleWaiter{ - ctx: n.waiter.ctx, - cancel: n.waiter.cancel, - } + w := n.waiter.subWaiter() n.waiters[w] = true return w }