From d49d3704fad609ff237ce73f55c79ba9e6829e52 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 8 Feb 2023 09:04:44 +0100 Subject: [PATCH] Use interface for client callbacks. --- client.go | 40 ++++++++++++++++++++++++---------------- hub.go | 30 +++++++++++++++++++++--------- proxy/proxy_client.go | 23 ++++++++++++++++++++--- proxy/proxy_server.go | 15 --------------- 4 files changed, 65 insertions(+), 43 deletions(-) diff --git a/client.go b/client.go index 3a72aad..7c37232 100644 --- a/client.go +++ b/client.go @@ -93,9 +93,20 @@ type WritableClientMessage interface { CloseAfterSend(session Session) bool } +type ClientHandler interface { + OnClosed(*Client) + OnMessageReceived(*Client, []byte) + OnRTTReceived(*Client, time.Duration) +} + +type ClientGeoIpHandler interface { + OnLookupCountry(*Client) string +} + type Client struct { conn *websocket.Conn addr string + handler ClientHandler agent string closed uint32 country *string @@ -109,14 +120,9 @@ type Client struct { closeOnce sync.Once messagesDone chan struct{} messageChan chan *bytes.Buffer - - OnLookupCountry func(*Client) string - OnClosed func(*Client) - OnMessageReceived func(*Client, []byte) - OnRTTReceived func(*Client, time.Duration) } -func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) { +func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler ClientHandler) (*Client, error) { remoteAddress = strings.TrimSpace(remoteAddress) if remoteAddress == "" { remoteAddress = "unknown remote address" @@ -130,20 +136,17 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Clien agent: agent, logRTT: true, } - client.SetConn(conn, remoteAddress) + client.SetConn(conn, remoteAddress, handler) return client, nil } -func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string) { +func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler ClientHandler) { c.conn = conn c.addr = remoteAddress + c.handler = handler c.closer = NewCloser() c.messageChan = make(chan *bytes.Buffer, 16) c.messagesDone = make(chan struct{}) - c.OnLookupCountry = func(client *Client) string { return unknownCountry } - c.OnClosed = func(client *Client) {} - c.OnMessageReceived = func(client *Client, data []byte) {} - c.OnRTTReceived = func(c *Client, d time.Duration) {} } func (c *Client) IsConnected() bool { @@ -172,7 +175,12 @@ func (c *Client) UserAgent() string { func (c *Client) Country() string { if c.country == nil { - country := c.OnLookupCountry(c) + var country string + if handler, ok := c.handler.(ClientGeoIpHandler); ok { + country = handler.OnLookupCountry(c) + } else { + country = unknownCountry + } c.country = &country } @@ -207,7 +215,7 @@ func (c *Client) doClose() { c.closer.Close() <-c.messagesDone - c.OnClosed(c) + c.handler.OnClosed(c) c.SetSession(nil) } } @@ -276,7 +284,7 @@ func (c *Client) ReadPump() { log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) } } - c.OnRTTReceived(c, rtt) + c.handler.OnRTTReceived(c, rtt) } return nil }) @@ -337,7 +345,7 @@ func (c *Client) processMessages() { break } - c.OnMessageReceived(c, buffer.Bytes()) + c.handler.OnMessageReceived(c, buffer.Bytes()) bufferPool.Put(buffer) } diff --git a/hub.go b/hub.go index d0163f0..1832d20 100644 --- a/hub.go +++ b/hub.go @@ -2315,20 +2315,12 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { return } - client, err := NewClient(conn, addr, agent) + client, err := NewClient(conn, addr, agent, h) if err != nil { log.Printf("Could not create client for %s: %s", addr, err) return } - if h.geoip != nil { - client.OnLookupCountry = h.lookupClientCountry - } - client.OnMessageReceived = h.processMessage - client.OnClosed = func(client *Client) { - h.processUnregister(client) - } - h.processNewClient(client) go func(h *Hub) { atomic.AddUint32(&h.writePumpActive, 1) @@ -2341,3 +2333,23 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { client.ReadPump() }(h) } + +func (h *Hub) OnLookupCountry(client *Client) string { + if h.geoip == nil { + return unknownCountry + } + + return h.lookupClientCountry(client) +} + +func (h *Hub) OnClosed(client *Client) { + h.processUnregister(client) +} + +func (h *Hub) OnMessageReceived(client *Client, data []byte) { + h.processMessage(client, data) +} + +func (h *Hub) OnRTTReceived(client *Client, rtt time.Duration) { + // Ignore +} diff --git a/proxy/proxy_client.go b/proxy/proxy_client.go index 10ccf7d..c9c495a 100644 --- a/proxy/proxy_client.go +++ b/proxy/proxy_client.go @@ -23,11 +23,11 @@ package main import ( "sync/atomic" + "time" "unsafe" "github.com/gorilla/websocket" - - "github.com/strukturag/nextcloud-spreed-signaling" + signaling "github.com/strukturag/nextcloud-spreed-signaling" ) type ProxyClient struct { @@ -42,7 +42,7 @@ func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*Pro client := &ProxyClient{ proxy: proxy, } - client.SetConn(conn, addr) + client.SetConn(conn, addr, client) return client, nil } @@ -53,3 +53,20 @@ func (c *ProxyClient) GetSession() *ProxySession { func (c *ProxyClient) SetSession(session *ProxySession) { atomic.StorePointer(&c.session, unsafe.Pointer(session)) } + +func (c *ProxyClient) OnClosed(client *signaling.Client) { + if session := c.GetSession(); session != nil { + session.MarkUsed() + } + c.proxy.clientClosed(&c.Client) +} + +func (c *ProxyClient) OnMessageReceived(client *signaling.Client, data []byte) { + c.proxy.processMessage(c, data) +} + +func (c *ProxyClient) OnRTTReceived(client *signaling.Client, rtt time.Duration) { + if session := c.GetSession(); session != nil { + session.MarkUsed() + } +} diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index d1e2938..ac837c3 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -433,21 +433,6 @@ func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) { return } - client.OnClosed = func(c *signaling.Client) { - if session := client.GetSession(); session != nil { - session.MarkUsed() - } - s.clientClosed(c) - } - client.OnMessageReceived = func(c *signaling.Client, data []byte) { - s.processMessage(client, data) - } - client.OnRTTReceived = func(c *signaling.Client, rtt time.Duration) { - if session := client.GetSession(); session != nil { - session.MarkUsed() - } - } - go client.WritePump() go client.ReadPump() }