Use interface for client callbacks.

This commit is contained in:
Joachim Bauch 2023-02-08 09:04:44 +01:00
parent e9f80c6b4d
commit d49d3704fa
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
4 changed files with 65 additions and 43 deletions

View file

@ -93,9 +93,20 @@ type WritableClientMessage interface {
CloseAfterSend(session Session) bool CloseAfterSend(session Session) bool
} }
type ClientHandler interface {
OnClosed(*Client)
OnMessageReceived(*Client, []byte)
OnRTTReceived(*Client, time.Duration)
}
type ClientGeoIpHandler interface {
OnLookupCountry(*Client) string
}
type Client struct { type Client struct {
conn *websocket.Conn conn *websocket.Conn
addr string addr string
handler ClientHandler
agent string agent string
closed uint32 closed uint32
country *string country *string
@ -109,14 +120,9 @@ type Client struct {
closeOnce sync.Once closeOnce sync.Once
messagesDone chan struct{} messagesDone chan struct{}
messageChan chan *bytes.Buffer messageChan chan *bytes.Buffer
OnLookupCountry func(*Client) string
OnClosed func(*Client)
OnMessageReceived func(*Client, []byte)
OnRTTReceived func(*Client, time.Duration)
} }
func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Client, error) { func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler ClientHandler) (*Client, error) {
remoteAddress = strings.TrimSpace(remoteAddress) remoteAddress = strings.TrimSpace(remoteAddress)
if remoteAddress == "" { if remoteAddress == "" {
remoteAddress = "unknown remote address" remoteAddress = "unknown remote address"
@ -130,20 +136,17 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string) (*Clien
agent: agent, agent: agent,
logRTT: true, logRTT: true,
} }
client.SetConn(conn, remoteAddress) client.SetConn(conn, remoteAddress, handler)
return client, nil return client, nil
} }
func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string) { func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler ClientHandler) {
c.conn = conn c.conn = conn
c.addr = remoteAddress c.addr = remoteAddress
c.handler = handler
c.closer = NewCloser() c.closer = NewCloser()
c.messageChan = make(chan *bytes.Buffer, 16) c.messageChan = make(chan *bytes.Buffer, 16)
c.messagesDone = make(chan struct{}) c.messagesDone = make(chan struct{})
c.OnLookupCountry = func(client *Client) string { return unknownCountry }
c.OnClosed = func(client *Client) {}
c.OnMessageReceived = func(client *Client, data []byte) {}
c.OnRTTReceived = func(c *Client, d time.Duration) {}
} }
func (c *Client) IsConnected() bool { func (c *Client) IsConnected() bool {
@ -172,7 +175,12 @@ func (c *Client) UserAgent() string {
func (c *Client) Country() string { func (c *Client) Country() string {
if c.country == nil { if c.country == nil {
country := c.OnLookupCountry(c) var country string
if handler, ok := c.handler.(ClientGeoIpHandler); ok {
country = handler.OnLookupCountry(c)
} else {
country = unknownCountry
}
c.country = &country c.country = &country
} }
@ -207,7 +215,7 @@ func (c *Client) doClose() {
c.closer.Close() c.closer.Close()
<-c.messagesDone <-c.messagesDone
c.OnClosed(c) c.handler.OnClosed(c)
c.SetSession(nil) c.SetSession(nil)
} }
} }
@ -276,7 +284,7 @@ func (c *Client) ReadPump() {
log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt)
} }
} }
c.OnRTTReceived(c, rtt) c.handler.OnRTTReceived(c, rtt)
} }
return nil return nil
}) })
@ -337,7 +345,7 @@ func (c *Client) processMessages() {
break break
} }
c.OnMessageReceived(c, buffer.Bytes()) c.handler.OnMessageReceived(c, buffer.Bytes())
bufferPool.Put(buffer) bufferPool.Put(buffer)
} }

30
hub.go
View file

@ -2315,20 +2315,12 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
return return
} }
client, err := NewClient(conn, addr, agent) client, err := NewClient(conn, addr, agent, h)
if err != nil { if err != nil {
log.Printf("Could not create client for %s: %s", addr, err) log.Printf("Could not create client for %s: %s", addr, err)
return return
} }
if h.geoip != nil {
client.OnLookupCountry = h.lookupClientCountry
}
client.OnMessageReceived = h.processMessage
client.OnClosed = func(client *Client) {
h.processUnregister(client)
}
h.processNewClient(client) h.processNewClient(client)
go func(h *Hub) { go func(h *Hub) {
atomic.AddUint32(&h.writePumpActive, 1) atomic.AddUint32(&h.writePumpActive, 1)
@ -2341,3 +2333,23 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
client.ReadPump() client.ReadPump()
}(h) }(h)
} }
func (h *Hub) OnLookupCountry(client *Client) string {
if h.geoip == nil {
return unknownCountry
}
return h.lookupClientCountry(client)
}
func (h *Hub) OnClosed(client *Client) {
h.processUnregister(client)
}
func (h *Hub) OnMessageReceived(client *Client, data []byte) {
h.processMessage(client, data)
}
func (h *Hub) OnRTTReceived(client *Client, rtt time.Duration) {
// Ignore
}

View file

@ -23,11 +23,11 @@ package main
import ( import (
"sync/atomic" "sync/atomic"
"time"
"unsafe" "unsafe"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
signaling "github.com/strukturag/nextcloud-spreed-signaling"
"github.com/strukturag/nextcloud-spreed-signaling"
) )
type ProxyClient struct { type ProxyClient struct {
@ -42,7 +42,7 @@ func NewProxyClient(proxy *ProxyServer, conn *websocket.Conn, addr string) (*Pro
client := &ProxyClient{ client := &ProxyClient{
proxy: proxy, proxy: proxy,
} }
client.SetConn(conn, addr) client.SetConn(conn, addr, client)
return client, nil return client, nil
} }
@ -53,3 +53,20 @@ func (c *ProxyClient) GetSession() *ProxySession {
func (c *ProxyClient) SetSession(session *ProxySession) { func (c *ProxyClient) SetSession(session *ProxySession) {
atomic.StorePointer(&c.session, unsafe.Pointer(session)) atomic.StorePointer(&c.session, unsafe.Pointer(session))
} }
func (c *ProxyClient) OnClosed(client *signaling.Client) {
if session := c.GetSession(); session != nil {
session.MarkUsed()
}
c.proxy.clientClosed(&c.Client)
}
func (c *ProxyClient) OnMessageReceived(client *signaling.Client, data []byte) {
c.proxy.processMessage(c, data)
}
func (c *ProxyClient) OnRTTReceived(client *signaling.Client, rtt time.Duration) {
if session := c.GetSession(); session != nil {
session.MarkUsed()
}
}

View file

@ -433,21 +433,6 @@ func (s *ProxyServer) proxyHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
client.OnClosed = func(c *signaling.Client) {
if session := client.GetSession(); session != nil {
session.MarkUsed()
}
s.clientClosed(c)
}
client.OnMessageReceived = func(c *signaling.Client, data []byte) {
s.processMessage(client, data)
}
client.OnRTTReceived = func(c *signaling.Client, rtt time.Duration) {
if session := client.GetSession(); session != nil {
session.MarkUsed()
}
}
go client.WritePump() go client.WritePump()
go client.ReadPump() go client.ReadPump()
} }