Add "HandlerClient" interface to support custom implementations.

This commit is contained in:
Joachim Bauch 2024-04-23 10:23:13 +02:00
parent 3721fb131f
commit 2468443572
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
8 changed files with 202 additions and 111 deletions

106
client.go
View file

@ -92,14 +92,32 @@ type WritableClientMessage interface {
CloseAfterSend(session Session) bool
}
type HandlerClient interface {
RemoteAddr() string
Country() string
UserAgent() string
IsConnected() bool
IsAuthenticated() bool
GetSession() Session
SetSession(session Session)
SendError(e *Error) bool
SendByeResponse(message *ClientMessage) bool
SendByeResponseWithReason(message *ClientMessage, reason string) bool
SendMessage(message WritableClientMessage) bool
Close()
}
type ClientHandler interface {
OnClosed(*Client)
OnMessageReceived(*Client, []byte)
OnRTTReceived(*Client, time.Duration)
OnClosed(HandlerClient)
OnMessageReceived(HandlerClient, []byte)
OnRTTReceived(HandlerClient, time.Duration)
}
type ClientGeoIpHandler interface {
OnLookupCountry(*Client) string
OnLookupCountry(HandlerClient) string
}
type Client struct {
@ -111,7 +129,8 @@ type Client struct {
country *string
logRTT bool
session atomic.Pointer[ClientSession]
session atomic.Pointer[Session]
sessionId atomic.Pointer[string]
mu sync.Mutex
@ -142,12 +161,16 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler
func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler ClientHandler) {
c.conn = conn
c.addr = remoteAddress
c.handler = handler
c.SetHandler(handler)
c.closer = NewCloser()
c.messageChan = make(chan *bytes.Buffer, 16)
c.messagesDone = make(chan struct{})
}
func (c *Client) SetHandler(handler ClientHandler) {
c.handler = handler
}
func (c *Client) IsConnected() bool {
return c.closed.Load() == 0
}
@ -156,12 +179,39 @@ func (c *Client) IsAuthenticated() bool {
return c.GetSession() != nil
}
func (c *Client) GetSession() *ClientSession {
return c.session.Load()
func (c *Client) GetSession() Session {
session := c.session.Load()
if session == nil {
return nil
}
return *session
}
func (c *Client) SetSession(session *ClientSession) {
c.session.Store(session)
func (c *Client) SetSession(session Session) {
if session == nil {
c.session.Store(nil)
} else {
c.session.Store(&session)
}
}
func (c *Client) SetSessionId(sessionId string) {
c.sessionId.Store(&sessionId)
}
func (c *Client) GetSessionId() string {
sessionId := c.sessionId.Load()
if sessionId == nil {
session := c.GetSession()
if session == nil {
return ""
}
return session.PublicId()
}
return *sessionId
}
func (c *Client) RemoteAddr() string {
@ -234,12 +284,14 @@ func (c *Client) SendByeResponse(message *ClientMessage) bool {
func (c *Client) SendByeResponseWithReason(message *ClientMessage, reason string) bool {
response := &ServerMessage{
Type: "bye",
Bye: &ByeServerMessage{},
}
if message != nil {
response.Id = message.Id
}
if reason != "" {
if response.Bye == nil {
response.Bye = &ByeServerMessage{}
}
response.Bye.Reason = reason
}
return c.SendMessage(response)
@ -277,8 +329,8 @@ func (c *Client) ReadPump() {
rtt := now.Sub(time.Unix(0, ts))
if c.logRTT {
rtt_ms := rtt.Nanoseconds() / time.Millisecond.Nanoseconds()
if session := c.GetSession(); session != nil {
log.Printf("Client %s has RTT of %d ms (%s)", session.PublicId(), rtt_ms, rtt)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Client %s has RTT of %d ms (%s)", sessionId, rtt_ms, rtt)
} else {
log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt)
}
@ -296,8 +348,8 @@ func (c *Client) ReadPump() {
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived) {
if session := c.GetSession(); session != nil {
log.Printf("Error reading from client %s: %v", session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Error reading from client %s: %v", sessionId, err)
} else {
log.Printf("Error reading from %s: %v", addr, err)
}
@ -306,8 +358,8 @@ func (c *Client) ReadPump() {
}
if messageType != websocket.TextMessage {
if session := c.GetSession(); session != nil {
log.Printf("Unsupported message type %v from client %s", messageType, session.PublicId())
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Unsupported message type %v from client %s", messageType, sessionId)
} else {
log.Printf("Unsupported message type %v from %s", messageType, addr)
}
@ -319,8 +371,8 @@ func (c *Client) ReadPump() {
decodeBuffer.Reset()
if _, err := decodeBuffer.ReadFrom(reader); err != nil {
bufferPool.Put(decodeBuffer)
if session := c.GetSession(); session != nil {
log.Printf("Error reading message from client %s: %v", session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Error reading message from client %s: %v", sessionId, err)
} else {
log.Printf("Error reading message from %s: %v", addr, err)
}
@ -373,8 +425,8 @@ func (c *Client) writeInternal(message json.Marshaler) bool {
return false
}
if session := c.GetSession(); session != nil {
log.Printf("Could not send message %+v to client %s: %v", message, session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Could not send message %+v to client %s: %v", message, sessionId, err)
} else {
log.Printf("Could not send message %+v to %s: %v", message, c.RemoteAddr(), err)
}
@ -386,8 +438,8 @@ func (c *Client) writeInternal(message json.Marshaler) bool {
close:
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint
if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil {
if session := c.GetSession(); session != nil {
log.Printf("Could not send close message to client %s: %v", session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Could not send close message to client %s: %v", sessionId, err)
} else {
log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err)
}
@ -413,8 +465,8 @@ func (c *Client) writeError(e error) bool { // nolint
closeData := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, e.Error())
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint
if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil {
if session := c.GetSession(); session != nil {
log.Printf("Could not send close message to client %s: %v", session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Could not send close message to client %s: %v", sessionId, err)
} else {
log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err)
}
@ -462,8 +514,8 @@ func (c *Client) sendPing() bool {
msg := strconv.FormatInt(now, 10)
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint
if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil {
if session := c.GetSession(); session != nil {
log.Printf("Could not send ping to client %s: %v", session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Could not send ping to client %s: %v", sessionId, err)
} else {
log.Printf("Could not send ping to %s: %v", c.RemoteAddr(), err)
}

View file

@ -67,7 +67,7 @@ type ClientSession struct {
mu sync.Mutex
client *Client
client HandlerClient
room atomic.Pointer[Room]
roomJoinTime atomic.Int64
roomSessionId string
@ -500,14 +500,14 @@ func (s *ClientSession) doUnsubscribeRoomEvents(notify bool) {
s.roomSessionId = ""
}
func (s *ClientSession) ClearClient(client *Client) {
func (s *ClientSession) ClearClient(client HandlerClient) {
s.mu.Lock()
defer s.mu.Unlock()
s.clearClientLocked(client)
}
func (s *ClientSession) clearClientLocked(client *Client) {
func (s *ClientSession) clearClientLocked(client HandlerClient) {
if s.client == nil {
return
} else if client != nil && s.client != client {
@ -520,18 +520,18 @@ func (s *ClientSession) clearClientLocked(client *Client) {
prevClient.SetSession(nil)
}
func (s *ClientSession) GetClient() *Client {
func (s *ClientSession) GetClient() HandlerClient {
s.mu.Lock()
defer s.mu.Unlock()
return s.getClientUnlocked()
}
func (s *ClientSession) getClientUnlocked() *Client {
func (s *ClientSession) getClientUnlocked() HandlerClient {
return s.client
}
func (s *ClientSession) SetClient(client *Client) *Client {
func (s *ClientSession) SetClient(client HandlerClient) HandlerClient {
if client == nil {
panic("Use ClearClient to set the client to nil")
}
@ -1341,7 +1341,7 @@ func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage {
}
}
func (s *ClientSession) NotifySessionResumed(client *Client) {
func (s *ClientSession) NotifySessionResumed(client HandlerClient) {
s.mu.Lock()
if len(s.pendingClientMessages) == 0 {
s.mu.Unlock()

148
hub.go
View file

@ -135,7 +135,7 @@ type Hub struct {
ru sync.RWMutex
sid atomic.Uint64
clients map[uint64]*Client
clients map[uint64]HandlerClient
sessions map[uint64]Session
rooms map[string]*Room
@ -153,7 +153,7 @@ type Hub struct {
expiredSessions map[Session]time.Time
anonymousSessions map[*ClientSession]time.Time
expectHelloClients map[*Client]time.Time
expectHelloClients map[HandlerClient]time.Time
dialoutSessions map[*ClientSession]bool
backendTimeout time.Duration
@ -324,7 +324,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
roomInCall: make(chan *BackendServerRoomRequest),
roomParticipants: make(chan *BackendServerRoomRequest),
clients: make(map[uint64]*Client),
clients: make(map[uint64]HandlerClient),
sessions: make(map[uint64]Session),
rooms: make(map[string]*Room),
@ -341,7 +341,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
expiredSessions: make(map[Session]time.Time),
anonymousSessions: make(map[*ClientSession]time.Time),
expectHelloClients: make(map[*Client]time.Time),
expectHelloClients: make(map[HandlerClient]time.Time),
dialoutSessions: make(map[*ClientSession]bool),
backendTimeout: backendTimeout,
@ -690,15 +690,13 @@ func (h *Hub) startWaitAnonymousSessionRoomLocked(session *ClientSession) {
h.anonymousSessions[session] = now.Add(anonmyousJoinRoomTimeout)
}
func (h *Hub) startExpectHello(client *Client) {
func (h *Hub) startExpectHello(client HandlerClient) {
h.mu.Lock()
defer h.mu.Unlock()
if !client.IsConnected() {
return
}
client.mu.Lock()
defer client.mu.Unlock()
if client.IsAuthenticated() {
return
}
@ -708,12 +706,12 @@ func (h *Hub) startExpectHello(client *Client) {
h.expectHelloClients[client] = now.Add(initialHelloTimeout)
}
func (h *Hub) processNewClient(client *Client) {
func (h *Hub) processNewClient(client HandlerClient) {
h.startExpectHello(client)
h.sendWelcome(client)
}
func (h *Hub) sendWelcome(client *Client) {
func (h *Hub) sendWelcome(client HandlerClient) {
client.SendMessage(h.getWelcomeMessage())
}
@ -730,17 +728,24 @@ func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData {
return sessionIdData
}
func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *Backend, auth *BackendClientResponse) {
if !client.IsConnected() {
func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend *Backend, auth *BackendClientResponse) {
if !c.IsConnected() {
// Client disconnected while waiting for "hello" response.
return
}
if auth.Type == "error" {
client.SendMessage(message.NewErrorServerMessage(auth.Error))
c.SendMessage(message.NewErrorServerMessage(auth.Error))
return
} else if auth.Type != "auth" {
client.SendMessage(message.NewErrorServerMessage(UserAuthFailed))
c.SendMessage(message.NewErrorServerMessage(UserAuthFailed))
return
}
client, ok := c.(*Client)
if !ok {
log.Printf("Can't register non-client %T", c)
client.SendMessage(message.NewWrappedErrorServerMessage(errors.New("can't register non-client")))
return
}
@ -844,7 +849,7 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
h.sendHelloResponse(session, message)
}
func (h *Hub) processUnregister(client *Client) *ClientSession {
func (h *Hub) processUnregister(client HandlerClient) Session {
session := client.GetSession()
h.mu.Lock()
@ -857,14 +862,18 @@ func (h *Hub) processUnregister(client *Client) *ClientSession {
h.mu.Unlock()
if session != nil {
log.Printf("Unregister %s (private=%s)", session.PublicId(), session.PrivateId())
session.ClearClient(client)
if c, ok := client.(*Client); ok {
if cs, ok := session.(*ClientSession); ok {
cs.ClearClient(c)
}
}
}
client.Close()
return session
}
func (h *Hub) processMessage(client *Client, data []byte) {
func (h *Hub) processMessage(client HandlerClient, data []byte) {
var message ClientMessage
if err := message.UnmarshalJSON(data); err != nil {
if session := client.GetSession(); session != nil {
@ -944,7 +953,7 @@ func (h *Hub) sendHelloResponse(session *ClientSession, message *ClientMessage)
return session.SendMessage(response)
}
func (h *Hub) processHello(client *Client, message *ClientMessage) {
func (h *Hub) processHello(client HandlerClient, message *ClientMessage) {
resumeId := message.Hello.ResumeId
if resumeId != "" {
data := h.decodeSessionId(resumeId, privateSessionName)
@ -1013,7 +1022,7 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) {
}
}
func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
url := message.Hello.Auth.parsedUrl
backend := h.backend.GetBackend(url)
if backend == nil {
@ -1035,7 +1044,7 @@ func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend,
return backend, &auth, nil
}
func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
func (h *Hub) processHelloV2(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
url := message.Hello.Auth.parsedUrl
backend := h.backend.GetBackend(url)
if backend == nil {
@ -1141,11 +1150,11 @@ func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend,
return backend, auth, nil
}
func (h *Hub) processHelloClient(client *Client, message *ClientMessage) {
func (h *Hub) processHelloClient(client HandlerClient, message *ClientMessage) {
// Make sure the client must send another "hello" in case of errors.
defer h.startExpectHello(client)
var authFunc func(*Client, *ClientMessage) (*Backend, *BackendClientResponse, error)
var authFunc func(HandlerClient, *ClientMessage) (*Backend, *BackendClientResponse, error)
switch message.Hello.Version {
case HelloVersionV1:
// Auth information contains a ticket that must be validated against the
@ -1172,7 +1181,7 @@ func (h *Hub) processHelloClient(client *Client, message *ClientMessage) {
h.processRegister(client, message, backend, auth)
}
func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) {
func (h *Hub) processHelloInternal(client HandlerClient, message *ClientMessage) {
defer h.startExpectHello(client)
if len(h.internalClientsSecret) == 0 {
client.SendMessage(message.NewErrorServerMessage(InvalidClientType))
@ -1261,8 +1270,12 @@ func (h *Hub) sendRoom(session *ClientSession, message *ClientMessage, room *Roo
return session.SendMessage(response)
}
func (h *Hub) processRoom(client *Client, message *ClientMessage) {
session := client.GetSession()
func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
session, ok := client.GetSession().(*ClientSession)
if !ok {
return
}
roomId := message.Room.RoomId
if roomId == "" {
if session == nil {
@ -1281,29 +1294,34 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) {
return
}
if session != nil {
if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) {
// Session already is in that room, no action needed.
roomSessionId := message.Room.SessionId
if roomSessionId == "" {
// TODO(jojo): Better make the session id required in the request.
log.Printf("User did not send a room session id, assuming session %s", session.PublicId())
roomSessionId = session.PublicId()
}
if session == nil {
session.SendMessage(message.NewErrorServerMessage(
NewError("not_authenticated", "Need to authenticate before joining rooms."),
))
return
}
if err := session.UpdateRoomSessionId(roomSessionId); err != nil {
log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err)
}
session.SendMessage(message.NewErrorServerMessage(
NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{
Room: &RoomServerMessage{
RoomId: room.id,
Properties: room.properties,
},
}),
))
return
if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) {
// Session already is in that room, no action needed.
roomSessionId := message.Room.SessionId
if roomSessionId == "" {
// TODO(jojo): Better make the session id required in the request.
log.Printf("User did not send a room session id, assuming session %s", session.PublicId())
roomSessionId = session.PublicId()
}
if err := session.UpdateRoomSessionId(roomSessionId); err != nil {
log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err)
}
session.SendMessage(message.NewErrorServerMessage(
NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{
Room: &RoomServerMessage{
RoomId: room.id,
Properties: room.properties,
},
}),
))
return
}
var room BackendClientResponse
@ -1430,14 +1448,14 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro
r.AddSession(session, room.Room.Session)
}
func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
msg := message.Message
session := client.GetSession()
if session == nil {
func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
session, ok := client.GetSession().(*ClientSession)
if session == nil || !ok {
// Client is not connected yet.
return
}
msg := message.Message
var recipient *ClientSession
var subject string
var clientData *MessageClientMessageData
@ -1484,10 +1502,10 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
// User is stopping to share his screen. Firefox doesn't properly clean
// up the peer connections in all cases, so make sure to stop publishing
// in the MCU.
go func(c *Client) {
go func(c HandlerClient) {
time.Sleep(cleanupScreenPublisherDelay)
session := c.GetSession()
if session == nil {
session, ok := c.GetSession().(*ClientSession)
if session == nil || !ok {
return
}
@ -1700,7 +1718,7 @@ func isAllowedToControl(session Session) bool {
return false
}
func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
func (h *Hub) processControlMsg(client HandlerClient, message *ClientMessage) {
msg := message.Control
session := client.GetSession()
if session == nil {
@ -1813,10 +1831,10 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
}
}
func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) {
func (h *Hub) processInternalMsg(client HandlerClient, message *ClientMessage) {
msg := message.Internal
session := client.GetSession()
if session == nil {
session, ok := client.GetSession().(*ClientSession)
if session == nil || !ok {
// Client is not connected yet.
return
} else if session.ClientType() != HelloClientTypeInternal {
@ -2030,7 +2048,7 @@ func isAllowedToUpdateTransientData(session Session) bool {
return false
}
func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) {
func (h *Hub) processTransientMsg(client HandlerClient, message *ClientMessage) {
msg := message.TransientData
session := client.GetSession()
if session == nil {
@ -2070,17 +2088,17 @@ func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) {
}
}
func sendNotAllowed(session *ClientSession, message *ClientMessage, reason string) {
func sendNotAllowed(session Session, message *ClientMessage, reason string) {
response := message.NewErrorServerMessage(NewError("not_allowed", reason))
session.SendMessage(response)
}
func sendMcuClientNotFound(session *ClientSession, message *ClientMessage) {
func sendMcuClientNotFound(session Session, message *ClientMessage) {
response := message.NewErrorServerMessage(NewError("client_not_found", "No MCU client found to send message to."))
session.SendMessage(response)
}
func sendMcuProcessingFailed(session *ClientSession, message *ClientMessage) {
func sendMcuProcessingFailed(session Session, message *ClientMessage) {
response := message.NewErrorServerMessage(NewError("processing_failed", "Processing of the message failed, please check server logs."))
session.SendMessage(response)
}
@ -2295,7 +2313,7 @@ func (h *Hub) sendMcuMessageResponse(session *ClientSession, mcuClient McuClient
session.SendMessage(response_message)
}
func (h *Hub) processByeMsg(client *Client, message *ClientMessage) {
func (h *Hub) processByeMsg(client HandlerClient, message *ClientMessage) {
client.SendByeResponse(message)
if session := h.processUnregister(client); session != nil {
session.Close()
@ -2412,7 +2430,7 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
}(h)
}
func (h *Hub) OnLookupCountry(client *Client) string {
func (h *Hub) OnLookupCountry(client HandlerClient) string {
ip := net.ParseIP(client.RemoteAddr())
if ip == nil {
return noCountry
@ -2444,14 +2462,14 @@ func (h *Hub) OnLookupCountry(client *Client) string {
return country
}
func (h *Hub) OnClosed(client *Client) {
func (h *Hub) OnClosed(client HandlerClient) {
h.processUnregister(client)
}
func (h *Hub) OnMessageReceived(client *Client, data []byte) {
func (h *Hub) OnMessageReceived(client HandlerClient, data []byte) {
h.processMessage(client, data)
}
func (h *Hub) OnRTTReceived(client *Client, rtt time.Duration) {
func (h *Hub) OnRTTReceived(client HandlerClient, rtt time.Duration) {
// Ignore
}

View file

@ -53,18 +53,18 @@ func (c *ProxyClient) SetSession(session *ProxySession) {
c.session.Store(session)
}
func (c *ProxyClient) OnClosed(client *signaling.Client) {
func (c *ProxyClient) OnClosed(client signaling.HandlerClient) {
if session := c.GetSession(); session != nil {
session.MarkUsed()
}
c.proxy.clientClosed(&c.Client)
}
func (c *ProxyClient) OnMessageReceived(client *signaling.Client, data []byte) {
func (c *ProxyClient) OnMessageReceived(client signaling.HandlerClient, data []byte) {
c.proxy.processMessage(c, data)
}
func (c *ProxyClient) OnRTTReceived(client *signaling.Client, rtt time.Duration) {
func (c *ProxyClient) OnRTTReceived(client signaling.HandlerClient, rtt time.Duration) {
if session := c.GetSession(); session != nil {
session.MarkUsed()
}

View file

@ -86,6 +86,14 @@ func (s *DummySession) HasPermission(permission Permission) bool {
return false
}
func (s *DummySession) SendError(e *Error) bool {
return false
}
func (s *DummySession) SendMessage(message *ServerMessage) bool {
return false
}
func checkSession(t *testing.T, sessions RoomSessions, sessionId string, roomSessionId string) Session {
session := &DummySession{
publicId: sessionId,

View file

@ -72,4 +72,7 @@ type Session interface {
Close()
HasPermission(permission Permission) bool
SendError(e *Error) bool
SendMessage(message *ServerMessage) bool
}

View file

@ -311,12 +311,14 @@ func (c *TestClient) WaitForClientRemoved(ctx context.Context) error {
for {
found := false
for _, client := range c.hub.clients {
client.mu.Lock()
conn := client.conn
client.mu.Unlock()
if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() {
found = true
break
if cc, ok := client.(*Client); ok {
cc.mu.Lock()
conn := cc.conn
cc.mu.Unlock()
if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() {
found = true
break
}
}
}
if !found {

View file

@ -51,7 +51,7 @@ type VirtualSession struct {
options *AddSessionOptions
}
func GetVirtualSessionId(session *ClientSession, sessionId string) string {
func GetVirtualSessionId(session Session, sessionId string) string {
return session.PublicId() + "|" + sessionId
}
@ -163,7 +163,7 @@ func (s *VirtualSession) Close() {
s.CloseWithFeedback(nil, nil)
}
func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *ClientMessage) {
func (s *VirtualSession) CloseWithFeedback(session Session, message *ClientMessage) {
room := s.GetRoom()
s.session.RemoveVirtualSession(s)
removed := s.session.hub.removeSession(s)
@ -173,7 +173,7 @@ func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *Clie
s.session.events.UnregisterSessionListener(s.PublicId(), s.session.Backend(), s)
}
func (s *VirtualSession) notifyBackendRemoved(room *Room, session *ClientSession, message *ClientMessage) {
func (s *VirtualSession) notifyBackendRemoved(room *Room, session Session, message *ClientMessage) {
ctx, cancel := context.WithTimeout(context.Background(), s.hub.backendTimeout)
defer cancel()
@ -321,3 +321,11 @@ func (s *VirtualSession) ProcessAsyncSessionMessage(message *AsyncMessage) {
}
}
}
func (s *VirtualSession) SendError(e *Error) bool {
return s.session.SendError(e)
}
func (s *VirtualSession) SendMessage(message *ServerMessage) bool {
return s.session.SendMessage(message)
}