From 246844357256440f264b4cab3832eccaaa4d2db4 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 23 Apr 2024 10:23:13 +0200 Subject: [PATCH] Add "HandlerClient" interface to support custom implementations. --- client.go | 106 ++++++++++++++++++++++-------- clientsession.go | 14 ++-- hub.go | 148 +++++++++++++++++++++++------------------- proxy/proxy_client.go | 6 +- roomsessions_test.go | 8 +++ session.go | 3 + testclient_test.go | 14 ++-- virtualsession.go | 14 +++- 8 files changed, 202 insertions(+), 111 deletions(-) diff --git a/client.go b/client.go index 0fe42e9..d918127 100644 --- a/client.go +++ b/client.go @@ -92,14 +92,32 @@ type WritableClientMessage interface { CloseAfterSend(session Session) bool } +type HandlerClient interface { + RemoteAddr() string + Country() string + UserAgent() string + IsConnected() bool + IsAuthenticated() bool + + GetSession() Session + SetSession(session Session) + + SendError(e *Error) bool + SendByeResponse(message *ClientMessage) bool + SendByeResponseWithReason(message *ClientMessage, reason string) bool + SendMessage(message WritableClientMessage) bool + + Close() +} + type ClientHandler interface { - OnClosed(*Client) - OnMessageReceived(*Client, []byte) - OnRTTReceived(*Client, time.Duration) + OnClosed(HandlerClient) + OnMessageReceived(HandlerClient, []byte) + OnRTTReceived(HandlerClient, time.Duration) } type ClientGeoIpHandler interface { - OnLookupCountry(*Client) string + OnLookupCountry(HandlerClient) string } type Client struct { @@ -111,7 +129,8 @@ type Client struct { country *string logRTT bool - session atomic.Pointer[ClientSession] + session atomic.Pointer[Session] + sessionId atomic.Pointer[string] mu sync.Mutex @@ -142,12 +161,16 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler ClientHandler) { c.conn = conn c.addr = remoteAddress - c.handler = handler + c.SetHandler(handler) c.closer = NewCloser() c.messageChan = make(chan *bytes.Buffer, 16) c.messagesDone = make(chan struct{}) } +func (c *Client) SetHandler(handler ClientHandler) { + c.handler = handler +} + func (c *Client) IsConnected() bool { return c.closed.Load() == 0 } @@ -156,12 +179,39 @@ func (c *Client) IsAuthenticated() bool { return c.GetSession() != nil } -func (c *Client) GetSession() *ClientSession { - return c.session.Load() +func (c *Client) GetSession() Session { + session := c.session.Load() + if session == nil { + return nil + } + + return *session } -func (c *Client) SetSession(session *ClientSession) { - c.session.Store(session) +func (c *Client) SetSession(session Session) { + if session == nil { + c.session.Store(nil) + } else { + c.session.Store(&session) + } +} + +func (c *Client) SetSessionId(sessionId string) { + c.sessionId.Store(&sessionId) +} + +func (c *Client) GetSessionId() string { + sessionId := c.sessionId.Load() + if sessionId == nil { + session := c.GetSession() + if session == nil { + return "" + } + + return session.PublicId() + } + + return *sessionId } func (c *Client) RemoteAddr() string { @@ -234,12 +284,14 @@ func (c *Client) SendByeResponse(message *ClientMessage) bool { func (c *Client) SendByeResponseWithReason(message *ClientMessage, reason string) bool { response := &ServerMessage{ Type: "bye", - Bye: &ByeServerMessage{}, } if message != nil { response.Id = message.Id } if reason != "" { + if response.Bye == nil { + response.Bye = &ByeServerMessage{} + } response.Bye.Reason = reason } return c.SendMessage(response) @@ -277,8 +329,8 @@ func (c *Client) ReadPump() { rtt := now.Sub(time.Unix(0, ts)) if c.logRTT { rtt_ms := rtt.Nanoseconds() / time.Millisecond.Nanoseconds() - if session := c.GetSession(); session != nil { - log.Printf("Client %s has RTT of %d ms (%s)", session.PublicId(), rtt_ms, rtt) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Client %s has RTT of %d ms (%s)", sessionId, rtt_ms, rtt) } else { log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) } @@ -296,8 +348,8 @@ func (c *Client) ReadPump() { websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - if session := c.GetSession(); session != nil { - log.Printf("Error reading from client %s: %v", session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Error reading from client %s: %v", sessionId, err) } else { log.Printf("Error reading from %s: %v", addr, err) } @@ -306,8 +358,8 @@ func (c *Client) ReadPump() { } if messageType != websocket.TextMessage { - if session := c.GetSession(); session != nil { - log.Printf("Unsupported message type %v from client %s", messageType, session.PublicId()) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Unsupported message type %v from client %s", messageType, sessionId) } else { log.Printf("Unsupported message type %v from %s", messageType, addr) } @@ -319,8 +371,8 @@ func (c *Client) ReadPump() { decodeBuffer.Reset() if _, err := decodeBuffer.ReadFrom(reader); err != nil { bufferPool.Put(decodeBuffer) - if session := c.GetSession(); session != nil { - log.Printf("Error reading message from client %s: %v", session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Error reading message from client %s: %v", sessionId, err) } else { log.Printf("Error reading message from %s: %v", addr, err) } @@ -373,8 +425,8 @@ func (c *Client) writeInternal(message json.Marshaler) bool { return false } - if session := c.GetSession(); session != nil { - log.Printf("Could not send message %+v to client %s: %v", message, session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Could not send message %+v to client %s: %v", message, sessionId, err) } else { log.Printf("Could not send message %+v to %s: %v", message, c.RemoteAddr(), err) } @@ -386,8 +438,8 @@ func (c *Client) writeInternal(message json.Marshaler) bool { close: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Could not send close message to client %s: %v", session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Could not send close message to client %s: %v", sessionId, err) } else { log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err) } @@ -413,8 +465,8 @@ func (c *Client) writeError(e error) bool { // nolint closeData := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, e.Error()) c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Could not send close message to client %s: %v", session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Could not send close message to client %s: %v", sessionId, err) } else { log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err) } @@ -462,8 +514,8 @@ func (c *Client) sendPing() bool { msg := strconv.FormatInt(now, 10) c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Could not send ping to client %s: %v", session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Could not send ping to client %s: %v", sessionId, err) } else { log.Printf("Could not send ping to %s: %v", c.RemoteAddr(), err) } diff --git a/clientsession.go b/clientsession.go index ff8a488..51bbf5e 100644 --- a/clientsession.go +++ b/clientsession.go @@ -67,7 +67,7 @@ type ClientSession struct { mu sync.Mutex - client *Client + client HandlerClient room atomic.Pointer[Room] roomJoinTime atomic.Int64 roomSessionId string @@ -500,14 +500,14 @@ func (s *ClientSession) doUnsubscribeRoomEvents(notify bool) { s.roomSessionId = "" } -func (s *ClientSession) ClearClient(client *Client) { +func (s *ClientSession) ClearClient(client HandlerClient) { s.mu.Lock() defer s.mu.Unlock() s.clearClientLocked(client) } -func (s *ClientSession) clearClientLocked(client *Client) { +func (s *ClientSession) clearClientLocked(client HandlerClient) { if s.client == nil { return } else if client != nil && s.client != client { @@ -520,18 +520,18 @@ func (s *ClientSession) clearClientLocked(client *Client) { prevClient.SetSession(nil) } -func (s *ClientSession) GetClient() *Client { +func (s *ClientSession) GetClient() HandlerClient { s.mu.Lock() defer s.mu.Unlock() return s.getClientUnlocked() } -func (s *ClientSession) getClientUnlocked() *Client { +func (s *ClientSession) getClientUnlocked() HandlerClient { return s.client } -func (s *ClientSession) SetClient(client *Client) *Client { +func (s *ClientSession) SetClient(client HandlerClient) HandlerClient { if client == nil { panic("Use ClearClient to set the client to nil") } @@ -1341,7 +1341,7 @@ func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage { } } -func (s *ClientSession) NotifySessionResumed(client *Client) { +func (s *ClientSession) NotifySessionResumed(client HandlerClient) { s.mu.Lock() if len(s.pendingClientMessages) == 0 { s.mu.Unlock() diff --git a/hub.go b/hub.go index 53f6ee4..e92b566 100644 --- a/hub.go +++ b/hub.go @@ -135,7 +135,7 @@ type Hub struct { ru sync.RWMutex sid atomic.Uint64 - clients map[uint64]*Client + clients map[uint64]HandlerClient sessions map[uint64]Session rooms map[string]*Room @@ -153,7 +153,7 @@ type Hub struct { expiredSessions map[Session]time.Time anonymousSessions map[*ClientSession]time.Time - expectHelloClients map[*Client]time.Time + expectHelloClients map[HandlerClient]time.Time dialoutSessions map[*ClientSession]bool backendTimeout time.Duration @@ -324,7 +324,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer roomInCall: make(chan *BackendServerRoomRequest), roomParticipants: make(chan *BackendServerRoomRequest), - clients: make(map[uint64]*Client), + clients: make(map[uint64]HandlerClient), sessions: make(map[uint64]Session), rooms: make(map[string]*Room), @@ -341,7 +341,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer expiredSessions: make(map[Session]time.Time), anonymousSessions: make(map[*ClientSession]time.Time), - expectHelloClients: make(map[*Client]time.Time), + expectHelloClients: make(map[HandlerClient]time.Time), dialoutSessions: make(map[*ClientSession]bool), backendTimeout: backendTimeout, @@ -690,15 +690,13 @@ func (h *Hub) startWaitAnonymousSessionRoomLocked(session *ClientSession) { h.anonymousSessions[session] = now.Add(anonmyousJoinRoomTimeout) } -func (h *Hub) startExpectHello(client *Client) { +func (h *Hub) startExpectHello(client HandlerClient) { h.mu.Lock() defer h.mu.Unlock() if !client.IsConnected() { return } - client.mu.Lock() - defer client.mu.Unlock() if client.IsAuthenticated() { return } @@ -708,12 +706,12 @@ func (h *Hub) startExpectHello(client *Client) { h.expectHelloClients[client] = now.Add(initialHelloTimeout) } -func (h *Hub) processNewClient(client *Client) { +func (h *Hub) processNewClient(client HandlerClient) { h.startExpectHello(client) h.sendWelcome(client) } -func (h *Hub) sendWelcome(client *Client) { +func (h *Hub) sendWelcome(client HandlerClient) { client.SendMessage(h.getWelcomeMessage()) } @@ -730,17 +728,24 @@ func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData { return sessionIdData } -func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *Backend, auth *BackendClientResponse) { - if !client.IsConnected() { +func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend *Backend, auth *BackendClientResponse) { + if !c.IsConnected() { // Client disconnected while waiting for "hello" response. return } if auth.Type == "error" { - client.SendMessage(message.NewErrorServerMessage(auth.Error)) + c.SendMessage(message.NewErrorServerMessage(auth.Error)) return } else if auth.Type != "auth" { - client.SendMessage(message.NewErrorServerMessage(UserAuthFailed)) + c.SendMessage(message.NewErrorServerMessage(UserAuthFailed)) + return + } + + client, ok := c.(*Client) + if !ok { + log.Printf("Can't register non-client %T", c) + client.SendMessage(message.NewWrappedErrorServerMessage(errors.New("can't register non-client"))) return } @@ -844,7 +849,7 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B h.sendHelloResponse(session, message) } -func (h *Hub) processUnregister(client *Client) *ClientSession { +func (h *Hub) processUnregister(client HandlerClient) Session { session := client.GetSession() h.mu.Lock() @@ -857,14 +862,18 @@ func (h *Hub) processUnregister(client *Client) *ClientSession { h.mu.Unlock() if session != nil { log.Printf("Unregister %s (private=%s)", session.PublicId(), session.PrivateId()) - session.ClearClient(client) + if c, ok := client.(*Client); ok { + if cs, ok := session.(*ClientSession); ok { + cs.ClearClient(c) + } + } } client.Close() return session } -func (h *Hub) processMessage(client *Client, data []byte) { +func (h *Hub) processMessage(client HandlerClient, data []byte) { var message ClientMessage if err := message.UnmarshalJSON(data); err != nil { if session := client.GetSession(); session != nil { @@ -944,7 +953,7 @@ func (h *Hub) sendHelloResponse(session *ClientSession, message *ClientMessage) return session.SendMessage(response) } -func (h *Hub) processHello(client *Client, message *ClientMessage) { +func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { resumeId := message.Hello.ResumeId if resumeId != "" { data := h.decodeSessionId(resumeId, privateSessionName) @@ -1013,7 +1022,7 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) { } } -func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) { +func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) { url := message.Hello.Auth.parsedUrl backend := h.backend.GetBackend(url) if backend == nil { @@ -1035,7 +1044,7 @@ func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend, return backend, &auth, nil } -func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) { +func (h *Hub) processHelloV2(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) { url := message.Hello.Auth.parsedUrl backend := h.backend.GetBackend(url) if backend == nil { @@ -1141,11 +1150,11 @@ func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, return backend, auth, nil } -func (h *Hub) processHelloClient(client *Client, message *ClientMessage) { +func (h *Hub) processHelloClient(client HandlerClient, message *ClientMessage) { // Make sure the client must send another "hello" in case of errors. defer h.startExpectHello(client) - var authFunc func(*Client, *ClientMessage) (*Backend, *BackendClientResponse, error) + var authFunc func(HandlerClient, *ClientMessage) (*Backend, *BackendClientResponse, error) switch message.Hello.Version { case HelloVersionV1: // Auth information contains a ticket that must be validated against the @@ -1172,7 +1181,7 @@ func (h *Hub) processHelloClient(client *Client, message *ClientMessage) { h.processRegister(client, message, backend, auth) } -func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) { +func (h *Hub) processHelloInternal(client HandlerClient, message *ClientMessage) { defer h.startExpectHello(client) if len(h.internalClientsSecret) == 0 { client.SendMessage(message.NewErrorServerMessage(InvalidClientType)) @@ -1261,8 +1270,12 @@ func (h *Hub) sendRoom(session *ClientSession, message *ClientMessage, room *Roo return session.SendMessage(response) } -func (h *Hub) processRoom(client *Client, message *ClientMessage) { - session := client.GetSession() +func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) { + session, ok := client.GetSession().(*ClientSession) + if !ok { + return + } + roomId := message.Room.RoomId if roomId == "" { if session == nil { @@ -1281,29 +1294,34 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { return } - if session != nil { - if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) { - // Session already is in that room, no action needed. - roomSessionId := message.Room.SessionId - if roomSessionId == "" { - // TODO(jojo): Better make the session id required in the request. - log.Printf("User did not send a room session id, assuming session %s", session.PublicId()) - roomSessionId = session.PublicId() - } + if session == nil { + session.SendMessage(message.NewErrorServerMessage( + NewError("not_authenticated", "Need to authenticate before joining rooms."), + )) + return + } - if err := session.UpdateRoomSessionId(roomSessionId); err != nil { - log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err) - } - session.SendMessage(message.NewErrorServerMessage( - NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{ - Room: &RoomServerMessage{ - RoomId: room.id, - Properties: room.properties, - }, - }), - )) - return + if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) { + // Session already is in that room, no action needed. + roomSessionId := message.Room.SessionId + if roomSessionId == "" { + // TODO(jojo): Better make the session id required in the request. + log.Printf("User did not send a room session id, assuming session %s", session.PublicId()) + roomSessionId = session.PublicId() } + + if err := session.UpdateRoomSessionId(roomSessionId); err != nil { + log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err) + } + session.SendMessage(message.NewErrorServerMessage( + NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{ + Room: &RoomServerMessage{ + RoomId: room.id, + Properties: room.properties, + }, + }), + )) + return } var room BackendClientResponse @@ -1430,14 +1448,14 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro r.AddSession(session, room.Room.Session) } -func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { - msg := message.Message - session := client.GetSession() - if session == nil { +func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) { + session, ok := client.GetSession().(*ClientSession) + if session == nil || !ok { // Client is not connected yet. return } + msg := message.Message var recipient *ClientSession var subject string var clientData *MessageClientMessageData @@ -1484,10 +1502,10 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { // User is stopping to share his screen. Firefox doesn't properly clean // up the peer connections in all cases, so make sure to stop publishing // in the MCU. - go func(c *Client) { + go func(c HandlerClient) { time.Sleep(cleanupScreenPublisherDelay) - session := c.GetSession() - if session == nil { + session, ok := c.GetSession().(*ClientSession) + if session == nil || !ok { return } @@ -1700,7 +1718,7 @@ func isAllowedToControl(session Session) bool { return false } -func (h *Hub) processControlMsg(client *Client, message *ClientMessage) { +func (h *Hub) processControlMsg(client HandlerClient, message *ClientMessage) { msg := message.Control session := client.GetSession() if session == nil { @@ -1813,10 +1831,10 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) { } } -func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { +func (h *Hub) processInternalMsg(client HandlerClient, message *ClientMessage) { msg := message.Internal - session := client.GetSession() - if session == nil { + session, ok := client.GetSession().(*ClientSession) + if session == nil || !ok { // Client is not connected yet. return } else if session.ClientType() != HelloClientTypeInternal { @@ -2030,7 +2048,7 @@ func isAllowedToUpdateTransientData(session Session) bool { return false } -func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) { +func (h *Hub) processTransientMsg(client HandlerClient, message *ClientMessage) { msg := message.TransientData session := client.GetSession() if session == nil { @@ -2070,17 +2088,17 @@ func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) { } } -func sendNotAllowed(session *ClientSession, message *ClientMessage, reason string) { +func sendNotAllowed(session Session, message *ClientMessage, reason string) { response := message.NewErrorServerMessage(NewError("not_allowed", reason)) session.SendMessage(response) } -func sendMcuClientNotFound(session *ClientSession, message *ClientMessage) { +func sendMcuClientNotFound(session Session, message *ClientMessage) { response := message.NewErrorServerMessage(NewError("client_not_found", "No MCU client found to send message to.")) session.SendMessage(response) } -func sendMcuProcessingFailed(session *ClientSession, message *ClientMessage) { +func sendMcuProcessingFailed(session Session, message *ClientMessage) { response := message.NewErrorServerMessage(NewError("processing_failed", "Processing of the message failed, please check server logs.")) session.SendMessage(response) } @@ -2295,7 +2313,7 @@ func (h *Hub) sendMcuMessageResponse(session *ClientSession, mcuClient McuClient session.SendMessage(response_message) } -func (h *Hub) processByeMsg(client *Client, message *ClientMessage) { +func (h *Hub) processByeMsg(client HandlerClient, message *ClientMessage) { client.SendByeResponse(message) if session := h.processUnregister(client); session != nil { session.Close() @@ -2412,7 +2430,7 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { }(h) } -func (h *Hub) OnLookupCountry(client *Client) string { +func (h *Hub) OnLookupCountry(client HandlerClient) string { ip := net.ParseIP(client.RemoteAddr()) if ip == nil { return noCountry @@ -2444,14 +2462,14 @@ func (h *Hub) OnLookupCountry(client *Client) string { return country } -func (h *Hub) OnClosed(client *Client) { +func (h *Hub) OnClosed(client HandlerClient) { h.processUnregister(client) } -func (h *Hub) OnMessageReceived(client *Client, data []byte) { +func (h *Hub) OnMessageReceived(client HandlerClient, data []byte) { h.processMessage(client, data) } -func (h *Hub) OnRTTReceived(client *Client, rtt time.Duration) { +func (h *Hub) OnRTTReceived(client HandlerClient, rtt time.Duration) { // Ignore } diff --git a/proxy/proxy_client.go b/proxy/proxy_client.go index dde4de8..cee7328 100644 --- a/proxy/proxy_client.go +++ b/proxy/proxy_client.go @@ -53,18 +53,18 @@ func (c *ProxyClient) SetSession(session *ProxySession) { c.session.Store(session) } -func (c *ProxyClient) OnClosed(client *signaling.Client) { +func (c *ProxyClient) OnClosed(client signaling.HandlerClient) { if session := c.GetSession(); session != nil { session.MarkUsed() } c.proxy.clientClosed(&c.Client) } -func (c *ProxyClient) OnMessageReceived(client *signaling.Client, data []byte) { +func (c *ProxyClient) OnMessageReceived(client signaling.HandlerClient, data []byte) { c.proxy.processMessage(c, data) } -func (c *ProxyClient) OnRTTReceived(client *signaling.Client, rtt time.Duration) { +func (c *ProxyClient) OnRTTReceived(client signaling.HandlerClient, rtt time.Duration) { if session := c.GetSession(); session != nil { session.MarkUsed() } diff --git a/roomsessions_test.go b/roomsessions_test.go index 5a8ffe0..805fa5b 100644 --- a/roomsessions_test.go +++ b/roomsessions_test.go @@ -86,6 +86,14 @@ func (s *DummySession) HasPermission(permission Permission) bool { return false } +func (s *DummySession) SendError(e *Error) bool { + return false +} + +func (s *DummySession) SendMessage(message *ServerMessage) bool { + return false +} + func checkSession(t *testing.T, sessions RoomSessions, sessionId string, roomSessionId string) Session { session := &DummySession{ publicId: sessionId, diff --git a/session.go b/session.go index 79286a3..5fd36c2 100644 --- a/session.go +++ b/session.go @@ -72,4 +72,7 @@ type Session interface { Close() HasPermission(permission Permission) bool + + SendError(e *Error) bool + SendMessage(message *ServerMessage) bool } diff --git a/testclient_test.go b/testclient_test.go index 23f1d63..9fcc5fd 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -311,12 +311,14 @@ func (c *TestClient) WaitForClientRemoved(ctx context.Context) error { for { found := false for _, client := range c.hub.clients { - client.mu.Lock() - conn := client.conn - client.mu.Unlock() - if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() { - found = true - break + if cc, ok := client.(*Client); ok { + cc.mu.Lock() + conn := cc.conn + cc.mu.Unlock() + if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() { + found = true + break + } } } if !found { diff --git a/virtualsession.go b/virtualsession.go index 6288396..7d17e6e 100644 --- a/virtualsession.go +++ b/virtualsession.go @@ -51,7 +51,7 @@ type VirtualSession struct { options *AddSessionOptions } -func GetVirtualSessionId(session *ClientSession, sessionId string) string { +func GetVirtualSessionId(session Session, sessionId string) string { return session.PublicId() + "|" + sessionId } @@ -163,7 +163,7 @@ func (s *VirtualSession) Close() { s.CloseWithFeedback(nil, nil) } -func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *ClientMessage) { +func (s *VirtualSession) CloseWithFeedback(session Session, message *ClientMessage) { room := s.GetRoom() s.session.RemoveVirtualSession(s) removed := s.session.hub.removeSession(s) @@ -173,7 +173,7 @@ func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *Clie s.session.events.UnregisterSessionListener(s.PublicId(), s.session.Backend(), s) } -func (s *VirtualSession) notifyBackendRemoved(room *Room, session *ClientSession, message *ClientMessage) { +func (s *VirtualSession) notifyBackendRemoved(room *Room, session Session, message *ClientMessage) { ctx, cancel := context.WithTimeout(context.Background(), s.hub.backendTimeout) defer cancel() @@ -321,3 +321,11 @@ func (s *VirtualSession) ProcessAsyncSessionMessage(message *AsyncMessage) { } } } + +func (s *VirtualSession) SendError(e *Error) bool { + return s.session.SendError(e) +} + +func (s *VirtualSession) SendMessage(message *ServerMessage) bool { + return s.session.SendMessage(message) +}