From 3721fb131f25ac65601a8562c7f24ce94a4fb01c Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 23 Apr 2024 08:51:34 +0200 Subject: [PATCH 1/4] Don't include empty "auth" field in hello client messages. --- api_signaling.go | 4 ++-- api_signaling_test.go | 32 ++++++++++++++++---------------- client/main.go | 2 +- testclient_test.go | 2 +- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/api_signaling.go b/api_signaling.go index 98523f3..0625bf1 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -387,7 +387,7 @@ type HelloClientMessage struct { Features []string `json:"features,omitempty"` // The authentication credentials. - Auth HelloClientMessageAuth `json:"auth"` + Auth *HelloClientMessageAuth `json:"auth,omitempty"` } func (m *HelloClientMessage) CheckValid() error { @@ -395,7 +395,7 @@ func (m *HelloClientMessage) CheckValid() error { return InvalidHelloVersion } if m.ResumeId == "" { - if m.Auth.Params == nil || len(*m.Auth.Params) == 0 { + if m.Auth == nil || m.Auth.Params == nil || len(*m.Auth.Params) == 0 { return fmt.Errorf("params missing") } if m.Auth.Type == "" { diff --git a/api_signaling_test.go b/api_signaling_test.go index 94e54c7..cc93ba7 100644 --- a/api_signaling_test.go +++ b/api_signaling_test.go @@ -95,14 +95,14 @@ func TestHelloClientMessage(t *testing.T) { // Hello version 1 &HelloClientMessage{ Version: HelloVersionV1, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, Url: "https://domain.invalid", }, }, &HelloClientMessage{ Version: HelloVersionV1, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Type: "client", Params: &json.RawMessage{'{', '}'}, Url: "https://domain.invalid", @@ -110,7 +110,7 @@ func TestHelloClientMessage(t *testing.T) { }, &HelloClientMessage{ Version: HelloVersionV1, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Type: "internal", Params: (*json.RawMessage)(&internalAuthParams), }, @@ -122,14 +122,14 @@ func TestHelloClientMessage(t *testing.T) { // Hello version 2 &HelloClientMessage{ Version: HelloVersionV2, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Params: (*json.RawMessage)(&tokenAuthParams), Url: "https://domain.invalid", }, }, &HelloClientMessage{ Version: HelloVersionV2, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Type: "client", Params: (*json.RawMessage)(&tokenAuthParams), Url: "https://domain.invalid", @@ -147,40 +147,40 @@ func TestHelloClientMessage(t *testing.T) { &HelloClientMessage{Version: HelloVersionV1}, &HelloClientMessage{ Version: HelloVersionV1, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, Type: "invalid-type", }, }, &HelloClientMessage{ Version: HelloVersionV1, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Url: "https://domain.invalid", }, }, &HelloClientMessage{ Version: HelloVersionV1, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, }, }, &HelloClientMessage{ Version: HelloVersionV1, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Params: &json.RawMessage{'{', '}'}, Url: "invalid-url", }, }, &HelloClientMessage{ Version: HelloVersionV1, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Type: "internal", Params: &json.RawMessage{'{', '}'}, }, }, &HelloClientMessage{ Version: HelloVersionV1, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Type: "internal", Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON. }, @@ -188,33 +188,33 @@ func TestHelloClientMessage(t *testing.T) { // Hello version 2 &HelloClientMessage{ Version: HelloVersionV2, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Url: "https://domain.invalid", }, }, &HelloClientMessage{ Version: HelloVersionV2, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Params: (*json.RawMessage)(&tokenAuthParams), }, }, &HelloClientMessage{ Version: HelloVersionV2, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Params: (*json.RawMessage)(&tokenAuthParams), Url: "invalid-url", }, }, &HelloClientMessage{ Version: HelloVersionV2, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Params: (*json.RawMessage)(&internalAuthParams), Url: "https://domain.invalid", }, }, &HelloClientMessage{ Version: HelloVersionV2, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON. Url: "https://domain.invalid", }, diff --git a/client/main.go b/client/main.go index cd98f59..1760167 100644 --- a/client/main.go +++ b/client/main.go @@ -601,7 +601,7 @@ func main() { Type: "hello", Hello: &signaling.HelloClientMessage{ Version: signaling.HelloVersionV1, - Auth: signaling.HelloClientMessageAuth{ + Auth: &signaling.HelloClientMessageAuth{ Url: backendUrl + "/auth", Params: &json.RawMessage{'{', '}'}, }, diff --git a/testclient_test.go b/testclient_test.go index e2b1106..23f1d63 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -493,7 +493,7 @@ func (c *TestClient) SendHelloParams(url string, version string, clientType stri Hello: &HelloClientMessage{ Version: version, Features: features, - Auth: HelloClientMessageAuth{ + Auth: &HelloClientMessageAuth{ Type: clientType, Url: url, Params: (*json.RawMessage)(&data), From 246844357256440f264b4cab3832eccaaa4d2db4 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 23 Apr 2024 10:23:13 +0200 Subject: [PATCH 2/4] Add "HandlerClient" interface to support custom implementations. --- client.go | 106 ++++++++++++++++++++++-------- clientsession.go | 14 ++-- hub.go | 148 +++++++++++++++++++++++------------------- proxy/proxy_client.go | 6 +- roomsessions_test.go | 8 +++ session.go | 3 + testclient_test.go | 14 ++-- virtualsession.go | 14 +++- 8 files changed, 202 insertions(+), 111 deletions(-) diff --git a/client.go b/client.go index 0fe42e9..d918127 100644 --- a/client.go +++ b/client.go @@ -92,14 +92,32 @@ type WritableClientMessage interface { CloseAfterSend(session Session) bool } +type HandlerClient interface { + RemoteAddr() string + Country() string + UserAgent() string + IsConnected() bool + IsAuthenticated() bool + + GetSession() Session + SetSession(session Session) + + SendError(e *Error) bool + SendByeResponse(message *ClientMessage) bool + SendByeResponseWithReason(message *ClientMessage, reason string) bool + SendMessage(message WritableClientMessage) bool + + Close() +} + type ClientHandler interface { - OnClosed(*Client) - OnMessageReceived(*Client, []byte) - OnRTTReceived(*Client, time.Duration) + OnClosed(HandlerClient) + OnMessageReceived(HandlerClient, []byte) + OnRTTReceived(HandlerClient, time.Duration) } type ClientGeoIpHandler interface { - OnLookupCountry(*Client) string + OnLookupCountry(HandlerClient) string } type Client struct { @@ -111,7 +129,8 @@ type Client struct { country *string logRTT bool - session atomic.Pointer[ClientSession] + session atomic.Pointer[Session] + sessionId atomic.Pointer[string] mu sync.Mutex @@ -142,12 +161,16 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler ClientHandler) { c.conn = conn c.addr = remoteAddress - c.handler = handler + c.SetHandler(handler) c.closer = NewCloser() c.messageChan = make(chan *bytes.Buffer, 16) c.messagesDone = make(chan struct{}) } +func (c *Client) SetHandler(handler ClientHandler) { + c.handler = handler +} + func (c *Client) IsConnected() bool { return c.closed.Load() == 0 } @@ -156,12 +179,39 @@ func (c *Client) IsAuthenticated() bool { return c.GetSession() != nil } -func (c *Client) GetSession() *ClientSession { - return c.session.Load() +func (c *Client) GetSession() Session { + session := c.session.Load() + if session == nil { + return nil + } + + return *session } -func (c *Client) SetSession(session *ClientSession) { - c.session.Store(session) +func (c *Client) SetSession(session Session) { + if session == nil { + c.session.Store(nil) + } else { + c.session.Store(&session) + } +} + +func (c *Client) SetSessionId(sessionId string) { + c.sessionId.Store(&sessionId) +} + +func (c *Client) GetSessionId() string { + sessionId := c.sessionId.Load() + if sessionId == nil { + session := c.GetSession() + if session == nil { + return "" + } + + return session.PublicId() + } + + return *sessionId } func (c *Client) RemoteAddr() string { @@ -234,12 +284,14 @@ func (c *Client) SendByeResponse(message *ClientMessage) bool { func (c *Client) SendByeResponseWithReason(message *ClientMessage, reason string) bool { response := &ServerMessage{ Type: "bye", - Bye: &ByeServerMessage{}, } if message != nil { response.Id = message.Id } if reason != "" { + if response.Bye == nil { + response.Bye = &ByeServerMessage{} + } response.Bye.Reason = reason } return c.SendMessage(response) @@ -277,8 +329,8 @@ func (c *Client) ReadPump() { rtt := now.Sub(time.Unix(0, ts)) if c.logRTT { rtt_ms := rtt.Nanoseconds() / time.Millisecond.Nanoseconds() - if session := c.GetSession(); session != nil { - log.Printf("Client %s has RTT of %d ms (%s)", session.PublicId(), rtt_ms, rtt) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Client %s has RTT of %d ms (%s)", sessionId, rtt_ms, rtt) } else { log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt) } @@ -296,8 +348,8 @@ func (c *Client) ReadPump() { websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - if session := c.GetSession(); session != nil { - log.Printf("Error reading from client %s: %v", session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Error reading from client %s: %v", sessionId, err) } else { log.Printf("Error reading from %s: %v", addr, err) } @@ -306,8 +358,8 @@ func (c *Client) ReadPump() { } if messageType != websocket.TextMessage { - if session := c.GetSession(); session != nil { - log.Printf("Unsupported message type %v from client %s", messageType, session.PublicId()) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Unsupported message type %v from client %s", messageType, sessionId) } else { log.Printf("Unsupported message type %v from %s", messageType, addr) } @@ -319,8 +371,8 @@ func (c *Client) ReadPump() { decodeBuffer.Reset() if _, err := decodeBuffer.ReadFrom(reader); err != nil { bufferPool.Put(decodeBuffer) - if session := c.GetSession(); session != nil { - log.Printf("Error reading message from client %s: %v", session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Error reading message from client %s: %v", sessionId, err) } else { log.Printf("Error reading message from %s: %v", addr, err) } @@ -373,8 +425,8 @@ func (c *Client) writeInternal(message json.Marshaler) bool { return false } - if session := c.GetSession(); session != nil { - log.Printf("Could not send message %+v to client %s: %v", message, session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Could not send message %+v to client %s: %v", message, sessionId, err) } else { log.Printf("Could not send message %+v to %s: %v", message, c.RemoteAddr(), err) } @@ -386,8 +438,8 @@ func (c *Client) writeInternal(message json.Marshaler) bool { close: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Could not send close message to client %s: %v", session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Could not send close message to client %s: %v", sessionId, err) } else { log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err) } @@ -413,8 +465,8 @@ func (c *Client) writeError(e error) bool { // nolint closeData := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, e.Error()) c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Could not send close message to client %s: %v", session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Could not send close message to client %s: %v", sessionId, err) } else { log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err) } @@ -462,8 +514,8 @@ func (c *Client) sendPing() bool { msg := strconv.FormatInt(now, 10) c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil { - if session := c.GetSession(); session != nil { - log.Printf("Could not send ping to client %s: %v", session.PublicId(), err) + if sessionId := c.GetSessionId(); sessionId != "" { + log.Printf("Could not send ping to client %s: %v", sessionId, err) } else { log.Printf("Could not send ping to %s: %v", c.RemoteAddr(), err) } diff --git a/clientsession.go b/clientsession.go index ff8a488..51bbf5e 100644 --- a/clientsession.go +++ b/clientsession.go @@ -67,7 +67,7 @@ type ClientSession struct { mu sync.Mutex - client *Client + client HandlerClient room atomic.Pointer[Room] roomJoinTime atomic.Int64 roomSessionId string @@ -500,14 +500,14 @@ func (s *ClientSession) doUnsubscribeRoomEvents(notify bool) { s.roomSessionId = "" } -func (s *ClientSession) ClearClient(client *Client) { +func (s *ClientSession) ClearClient(client HandlerClient) { s.mu.Lock() defer s.mu.Unlock() s.clearClientLocked(client) } -func (s *ClientSession) clearClientLocked(client *Client) { +func (s *ClientSession) clearClientLocked(client HandlerClient) { if s.client == nil { return } else if client != nil && s.client != client { @@ -520,18 +520,18 @@ func (s *ClientSession) clearClientLocked(client *Client) { prevClient.SetSession(nil) } -func (s *ClientSession) GetClient() *Client { +func (s *ClientSession) GetClient() HandlerClient { s.mu.Lock() defer s.mu.Unlock() return s.getClientUnlocked() } -func (s *ClientSession) getClientUnlocked() *Client { +func (s *ClientSession) getClientUnlocked() HandlerClient { return s.client } -func (s *ClientSession) SetClient(client *Client) *Client { +func (s *ClientSession) SetClient(client HandlerClient) HandlerClient { if client == nil { panic("Use ClearClient to set the client to nil") } @@ -1341,7 +1341,7 @@ func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage { } } -func (s *ClientSession) NotifySessionResumed(client *Client) { +func (s *ClientSession) NotifySessionResumed(client HandlerClient) { s.mu.Lock() if len(s.pendingClientMessages) == 0 { s.mu.Unlock() diff --git a/hub.go b/hub.go index 53f6ee4..e92b566 100644 --- a/hub.go +++ b/hub.go @@ -135,7 +135,7 @@ type Hub struct { ru sync.RWMutex sid atomic.Uint64 - clients map[uint64]*Client + clients map[uint64]HandlerClient sessions map[uint64]Session rooms map[string]*Room @@ -153,7 +153,7 @@ type Hub struct { expiredSessions map[Session]time.Time anonymousSessions map[*ClientSession]time.Time - expectHelloClients map[*Client]time.Time + expectHelloClients map[HandlerClient]time.Time dialoutSessions map[*ClientSession]bool backendTimeout time.Duration @@ -324,7 +324,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer roomInCall: make(chan *BackendServerRoomRequest), roomParticipants: make(chan *BackendServerRoomRequest), - clients: make(map[uint64]*Client), + clients: make(map[uint64]HandlerClient), sessions: make(map[uint64]Session), rooms: make(map[string]*Room), @@ -341,7 +341,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer expiredSessions: make(map[Session]time.Time), anonymousSessions: make(map[*ClientSession]time.Time), - expectHelloClients: make(map[*Client]time.Time), + expectHelloClients: make(map[HandlerClient]time.Time), dialoutSessions: make(map[*ClientSession]bool), backendTimeout: backendTimeout, @@ -690,15 +690,13 @@ func (h *Hub) startWaitAnonymousSessionRoomLocked(session *ClientSession) { h.anonymousSessions[session] = now.Add(anonmyousJoinRoomTimeout) } -func (h *Hub) startExpectHello(client *Client) { +func (h *Hub) startExpectHello(client HandlerClient) { h.mu.Lock() defer h.mu.Unlock() if !client.IsConnected() { return } - client.mu.Lock() - defer client.mu.Unlock() if client.IsAuthenticated() { return } @@ -708,12 +706,12 @@ func (h *Hub) startExpectHello(client *Client) { h.expectHelloClients[client] = now.Add(initialHelloTimeout) } -func (h *Hub) processNewClient(client *Client) { +func (h *Hub) processNewClient(client HandlerClient) { h.startExpectHello(client) h.sendWelcome(client) } -func (h *Hub) sendWelcome(client *Client) { +func (h *Hub) sendWelcome(client HandlerClient) { client.SendMessage(h.getWelcomeMessage()) } @@ -730,17 +728,24 @@ func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData { return sessionIdData } -func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *Backend, auth *BackendClientResponse) { - if !client.IsConnected() { +func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend *Backend, auth *BackendClientResponse) { + if !c.IsConnected() { // Client disconnected while waiting for "hello" response. return } if auth.Type == "error" { - client.SendMessage(message.NewErrorServerMessage(auth.Error)) + c.SendMessage(message.NewErrorServerMessage(auth.Error)) return } else if auth.Type != "auth" { - client.SendMessage(message.NewErrorServerMessage(UserAuthFailed)) + c.SendMessage(message.NewErrorServerMessage(UserAuthFailed)) + return + } + + client, ok := c.(*Client) + if !ok { + log.Printf("Can't register non-client %T", c) + client.SendMessage(message.NewWrappedErrorServerMessage(errors.New("can't register non-client"))) return } @@ -844,7 +849,7 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B h.sendHelloResponse(session, message) } -func (h *Hub) processUnregister(client *Client) *ClientSession { +func (h *Hub) processUnregister(client HandlerClient) Session { session := client.GetSession() h.mu.Lock() @@ -857,14 +862,18 @@ func (h *Hub) processUnregister(client *Client) *ClientSession { h.mu.Unlock() if session != nil { log.Printf("Unregister %s (private=%s)", session.PublicId(), session.PrivateId()) - session.ClearClient(client) + if c, ok := client.(*Client); ok { + if cs, ok := session.(*ClientSession); ok { + cs.ClearClient(c) + } + } } client.Close() return session } -func (h *Hub) processMessage(client *Client, data []byte) { +func (h *Hub) processMessage(client HandlerClient, data []byte) { var message ClientMessage if err := message.UnmarshalJSON(data); err != nil { if session := client.GetSession(); session != nil { @@ -944,7 +953,7 @@ func (h *Hub) sendHelloResponse(session *ClientSession, message *ClientMessage) return session.SendMessage(response) } -func (h *Hub) processHello(client *Client, message *ClientMessage) { +func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { resumeId := message.Hello.ResumeId if resumeId != "" { data := h.decodeSessionId(resumeId, privateSessionName) @@ -1013,7 +1022,7 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) { } } -func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) { +func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) { url := message.Hello.Auth.parsedUrl backend := h.backend.GetBackend(url) if backend == nil { @@ -1035,7 +1044,7 @@ func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend, return backend, &auth, nil } -func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) { +func (h *Hub) processHelloV2(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) { url := message.Hello.Auth.parsedUrl backend := h.backend.GetBackend(url) if backend == nil { @@ -1141,11 +1150,11 @@ func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, return backend, auth, nil } -func (h *Hub) processHelloClient(client *Client, message *ClientMessage) { +func (h *Hub) processHelloClient(client HandlerClient, message *ClientMessage) { // Make sure the client must send another "hello" in case of errors. defer h.startExpectHello(client) - var authFunc func(*Client, *ClientMessage) (*Backend, *BackendClientResponse, error) + var authFunc func(HandlerClient, *ClientMessage) (*Backend, *BackendClientResponse, error) switch message.Hello.Version { case HelloVersionV1: // Auth information contains a ticket that must be validated against the @@ -1172,7 +1181,7 @@ func (h *Hub) processHelloClient(client *Client, message *ClientMessage) { h.processRegister(client, message, backend, auth) } -func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) { +func (h *Hub) processHelloInternal(client HandlerClient, message *ClientMessage) { defer h.startExpectHello(client) if len(h.internalClientsSecret) == 0 { client.SendMessage(message.NewErrorServerMessage(InvalidClientType)) @@ -1261,8 +1270,12 @@ func (h *Hub) sendRoom(session *ClientSession, message *ClientMessage, room *Roo return session.SendMessage(response) } -func (h *Hub) processRoom(client *Client, message *ClientMessage) { - session := client.GetSession() +func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) { + session, ok := client.GetSession().(*ClientSession) + if !ok { + return + } + roomId := message.Room.RoomId if roomId == "" { if session == nil { @@ -1281,29 +1294,34 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) { return } - if session != nil { - if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) { - // Session already is in that room, no action needed. - roomSessionId := message.Room.SessionId - if roomSessionId == "" { - // TODO(jojo): Better make the session id required in the request. - log.Printf("User did not send a room session id, assuming session %s", session.PublicId()) - roomSessionId = session.PublicId() - } + if session == nil { + session.SendMessage(message.NewErrorServerMessage( + NewError("not_authenticated", "Need to authenticate before joining rooms."), + )) + return + } - if err := session.UpdateRoomSessionId(roomSessionId); err != nil { - log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err) - } - session.SendMessage(message.NewErrorServerMessage( - NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{ - Room: &RoomServerMessage{ - RoomId: room.id, - Properties: room.properties, - }, - }), - )) - return + if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) { + // Session already is in that room, no action needed. + roomSessionId := message.Room.SessionId + if roomSessionId == "" { + // TODO(jojo): Better make the session id required in the request. + log.Printf("User did not send a room session id, assuming session %s", session.PublicId()) + roomSessionId = session.PublicId() } + + if err := session.UpdateRoomSessionId(roomSessionId); err != nil { + log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err) + } + session.SendMessage(message.NewErrorServerMessage( + NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{ + Room: &RoomServerMessage{ + RoomId: room.id, + Properties: room.properties, + }, + }), + )) + return } var room BackendClientResponse @@ -1430,14 +1448,14 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro r.AddSession(session, room.Room.Session) } -func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { - msg := message.Message - session := client.GetSession() - if session == nil { +func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) { + session, ok := client.GetSession().(*ClientSession) + if session == nil || !ok { // Client is not connected yet. return } + msg := message.Message var recipient *ClientSession var subject string var clientData *MessageClientMessageData @@ -1484,10 +1502,10 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) { // User is stopping to share his screen. Firefox doesn't properly clean // up the peer connections in all cases, so make sure to stop publishing // in the MCU. - go func(c *Client) { + go func(c HandlerClient) { time.Sleep(cleanupScreenPublisherDelay) - session := c.GetSession() - if session == nil { + session, ok := c.GetSession().(*ClientSession) + if session == nil || !ok { return } @@ -1700,7 +1718,7 @@ func isAllowedToControl(session Session) bool { return false } -func (h *Hub) processControlMsg(client *Client, message *ClientMessage) { +func (h *Hub) processControlMsg(client HandlerClient, message *ClientMessage) { msg := message.Control session := client.GetSession() if session == nil { @@ -1813,10 +1831,10 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) { } } -func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) { +func (h *Hub) processInternalMsg(client HandlerClient, message *ClientMessage) { msg := message.Internal - session := client.GetSession() - if session == nil { + session, ok := client.GetSession().(*ClientSession) + if session == nil || !ok { // Client is not connected yet. return } else if session.ClientType() != HelloClientTypeInternal { @@ -2030,7 +2048,7 @@ func isAllowedToUpdateTransientData(session Session) bool { return false } -func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) { +func (h *Hub) processTransientMsg(client HandlerClient, message *ClientMessage) { msg := message.TransientData session := client.GetSession() if session == nil { @@ -2070,17 +2088,17 @@ func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) { } } -func sendNotAllowed(session *ClientSession, message *ClientMessage, reason string) { +func sendNotAllowed(session Session, message *ClientMessage, reason string) { response := message.NewErrorServerMessage(NewError("not_allowed", reason)) session.SendMessage(response) } -func sendMcuClientNotFound(session *ClientSession, message *ClientMessage) { +func sendMcuClientNotFound(session Session, message *ClientMessage) { response := message.NewErrorServerMessage(NewError("client_not_found", "No MCU client found to send message to.")) session.SendMessage(response) } -func sendMcuProcessingFailed(session *ClientSession, message *ClientMessage) { +func sendMcuProcessingFailed(session Session, message *ClientMessage) { response := message.NewErrorServerMessage(NewError("processing_failed", "Processing of the message failed, please check server logs.")) session.SendMessage(response) } @@ -2295,7 +2313,7 @@ func (h *Hub) sendMcuMessageResponse(session *ClientSession, mcuClient McuClient session.SendMessage(response_message) } -func (h *Hub) processByeMsg(client *Client, message *ClientMessage) { +func (h *Hub) processByeMsg(client HandlerClient, message *ClientMessage) { client.SendByeResponse(message) if session := h.processUnregister(client); session != nil { session.Close() @@ -2412,7 +2430,7 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) { }(h) } -func (h *Hub) OnLookupCountry(client *Client) string { +func (h *Hub) OnLookupCountry(client HandlerClient) string { ip := net.ParseIP(client.RemoteAddr()) if ip == nil { return noCountry @@ -2444,14 +2462,14 @@ func (h *Hub) OnLookupCountry(client *Client) string { return country } -func (h *Hub) OnClosed(client *Client) { +func (h *Hub) OnClosed(client HandlerClient) { h.processUnregister(client) } -func (h *Hub) OnMessageReceived(client *Client, data []byte) { +func (h *Hub) OnMessageReceived(client HandlerClient, data []byte) { h.processMessage(client, data) } -func (h *Hub) OnRTTReceived(client *Client, rtt time.Duration) { +func (h *Hub) OnRTTReceived(client HandlerClient, rtt time.Duration) { // Ignore } diff --git a/proxy/proxy_client.go b/proxy/proxy_client.go index dde4de8..cee7328 100644 --- a/proxy/proxy_client.go +++ b/proxy/proxy_client.go @@ -53,18 +53,18 @@ func (c *ProxyClient) SetSession(session *ProxySession) { c.session.Store(session) } -func (c *ProxyClient) OnClosed(client *signaling.Client) { +func (c *ProxyClient) OnClosed(client signaling.HandlerClient) { if session := c.GetSession(); session != nil { session.MarkUsed() } c.proxy.clientClosed(&c.Client) } -func (c *ProxyClient) OnMessageReceived(client *signaling.Client, data []byte) { +func (c *ProxyClient) OnMessageReceived(client signaling.HandlerClient, data []byte) { c.proxy.processMessage(c, data) } -func (c *ProxyClient) OnRTTReceived(client *signaling.Client, rtt time.Duration) { +func (c *ProxyClient) OnRTTReceived(client signaling.HandlerClient, rtt time.Duration) { if session := c.GetSession(); session != nil { session.MarkUsed() } diff --git a/roomsessions_test.go b/roomsessions_test.go index 5a8ffe0..805fa5b 100644 --- a/roomsessions_test.go +++ b/roomsessions_test.go @@ -86,6 +86,14 @@ func (s *DummySession) HasPermission(permission Permission) bool { return false } +func (s *DummySession) SendError(e *Error) bool { + return false +} + +func (s *DummySession) SendMessage(message *ServerMessage) bool { + return false +} + func checkSession(t *testing.T, sessions RoomSessions, sessionId string, roomSessionId string) Session { session := &DummySession{ publicId: sessionId, diff --git a/session.go b/session.go index 79286a3..5fd36c2 100644 --- a/session.go +++ b/session.go @@ -72,4 +72,7 @@ type Session interface { Close() HasPermission(permission Permission) bool + + SendError(e *Error) bool + SendMessage(message *ServerMessage) bool } diff --git a/testclient_test.go b/testclient_test.go index 23f1d63..9fcc5fd 100644 --- a/testclient_test.go +++ b/testclient_test.go @@ -311,12 +311,14 @@ func (c *TestClient) WaitForClientRemoved(ctx context.Context) error { for { found := false for _, client := range c.hub.clients { - client.mu.Lock() - conn := client.conn - client.mu.Unlock() - if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() { - found = true - break + if cc, ok := client.(*Client); ok { + cc.mu.Lock() + conn := cc.conn + cc.mu.Unlock() + if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() { + found = true + break + } } } if !found { diff --git a/virtualsession.go b/virtualsession.go index 6288396..7d17e6e 100644 --- a/virtualsession.go +++ b/virtualsession.go @@ -51,7 +51,7 @@ type VirtualSession struct { options *AddSessionOptions } -func GetVirtualSessionId(session *ClientSession, sessionId string) string { +func GetVirtualSessionId(session Session, sessionId string) string { return session.PublicId() + "|" + sessionId } @@ -163,7 +163,7 @@ func (s *VirtualSession) Close() { s.CloseWithFeedback(nil, nil) } -func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *ClientMessage) { +func (s *VirtualSession) CloseWithFeedback(session Session, message *ClientMessage) { room := s.GetRoom() s.session.RemoveVirtualSession(s) removed := s.session.hub.removeSession(s) @@ -173,7 +173,7 @@ func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *Clie s.session.events.UnregisterSessionListener(s.PublicId(), s.session.Backend(), s) } -func (s *VirtualSession) notifyBackendRemoved(room *Room, session *ClientSession, message *ClientMessage) { +func (s *VirtualSession) notifyBackendRemoved(room *Room, session Session, message *ClientMessage) { ctx, cancel := context.WithTimeout(context.Background(), s.hub.backendTimeout) defer cancel() @@ -321,3 +321,11 @@ func (s *VirtualSession) ProcessAsyncSessionMessage(message *AsyncMessage) { } } } + +func (s *VirtualSession) SendError(e *Error) bool { + return s.session.SendError(e) +} + +func (s *VirtualSession) SendMessage(message *ServerMessage) bool { + return s.session.SendMessage(message) +} From 0c2cefa63a2f69033849caa4c704fb84a00f7267 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 23 Apr 2024 11:09:04 +0200 Subject: [PATCH 3/4] Don't return "false" if message sending closed the connection. --- client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client.go b/client.go index d918127..948ad97 100644 --- a/client.go +++ b/client.go @@ -497,7 +497,6 @@ func (c *Client) writeMessageLocked(message WritableClientMessage) bool { go session.Close() } go c.Close() - return false } return true From 602452fa250e70529f022546129ddbb0278b3e0d Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 23 Apr 2024 11:46:32 +0200 Subject: [PATCH 4/4] Support resuming sessions that exist on a different Hub in the cluster. --- grpc_client.go | 104 ++++++++++++++ grpc_remote_client.go | 225 ++++++++++++++++++++++++++++++ grpc_server.go | 27 ++++ grpc_sessions.proto | 18 +++ hub.go | 131 ++++++++++++++++++ hub_test.go | 311 +++++++++++++++++++++++++++++++++++++++++- remotesession.go | 152 +++++++++++++++++++++ 7 files changed, 966 insertions(+), 2 deletions(-) create mode 100644 grpc_remote_client.go create mode 100644 remotesession.go diff --git a/grpc_client.go b/grpc_client.go index f2efa8d..0774ed7 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -26,6 +26,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "log" "net" "net/url" @@ -39,6 +40,7 @@ import ( "google.golang.org/grpc" codes "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" status "google.golang.org/grpc/status" ) @@ -51,6 +53,8 @@ const ( ) var ( + ErrNoSuchResumeId = fmt.Errorf("unknown resume id") + customResolverPrefix atomic.Uint64 ) @@ -185,6 +189,26 @@ func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) { return response.GetServerId(), nil } +func (c *GrpcClient) LookupResumeId(ctx context.Context, resumeId string) (*LookupResumeIdReply, error) { + statsGrpcClientCalls.WithLabelValues("LookupResumeId").Inc() + // TODO: Remove debug logging + log.Printf("Lookup resume id %s on %s", resumeId, c.Target()) + response, err := c.impl.LookupResumeId(ctx, &LookupResumeIdRequest{ + ResumeId: resumeId, + }, grpc.WaitForReady(true)) + if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + return nil, ErrNoSuchResumeId + } else if err != nil { + return nil, err + } + + if sessionId := response.GetSessionId(); sessionId == "" { + return nil, ErrNoSuchResumeId + } + + return response, nil +} + func (c *GrpcClient) LookupSessionId(ctx context.Context, roomSessionId string, disconnectReason string) (string, error) { statsGrpcClientCalls.WithLabelValues("LookupSessionId").Inc() // TODO: Remove debug logging @@ -258,6 +282,86 @@ func (c *GrpcClient) GetSessionCount(ctx context.Context, u *url.URL) (uint32, e return response.GetCount(), nil } +type ProxySessionReceiver interface { + RemoteAddr() string + Country() string + UserAgent() string + + OnProxyMessage(message *ServerSessionMessage) error + OnProxyClose(err error) +} + +type SessionProxy struct { + sessionId string + receiver ProxySessionReceiver + + sendMu sync.Mutex + client RpcSessions_ProxySessionClient +} + +func (p *SessionProxy) recvPump() { + var closeError error + defer func() { + p.receiver.OnProxyClose(closeError) + if err := p.Close(); err != nil { + log.Printf("Error closing proxy for session %s: %s", p.sessionId, err) + } + }() + + for { + msg, err := p.client.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + log.Printf("Error receiving message from proxy for session %s: %s", p.sessionId, err) + closeError = err + break + } + + if err := p.receiver.OnProxyMessage(msg); err != nil { + log.Printf("Error processing message %+v from proxy for session %s: %s", msg, p.sessionId, err) + } + } +} + +func (p *SessionProxy) Send(message *ClientSessionMessage) error { + p.sendMu.Lock() + defer p.sendMu.Unlock() + return p.client.Send(message) +} + +func (p *SessionProxy) Close() error { + p.sendMu.Lock() + defer p.sendMu.Unlock() + return p.client.CloseSend() +} + +func (c *GrpcClient) ProxySession(ctx context.Context, sessionId string, receiver ProxySessionReceiver) (*SessionProxy, error) { + statsGrpcClientCalls.WithLabelValues("ProxySession").Inc() + md := metadata.Pairs( + "sessionId", sessionId, + "remoteAddr", receiver.RemoteAddr(), + "country", receiver.Country(), + "userAgent", receiver.UserAgent(), + ) + client, err := c.impl.ProxySession(metadata.NewOutgoingContext(ctx, md), grpc.WaitForReady(true)) + if err != nil { + return nil, err + } + + proxy := &SessionProxy{ + sessionId: sessionId, + receiver: receiver, + + client: client, + } + + go proxy.recvPump() + return proxy, nil +} + type grpcClientsList struct { clients []*GrpcClient entry *DnsMonitorEntry diff --git a/grpc_remote_client.go b/grpc_remote_client.go new file mode 100644 index 0000000..e32c6c6 --- /dev/null +++ b/grpc_remote_client.go @@ -0,0 +1,225 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "sync/atomic" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +const ( + grpcRemoteClientMessageQueue = 16 +) + +func getMD(md metadata.MD, key string) string { + if values := md.Get(key); len(values) > 0 { + return values[0] + } + + return "" +} + +// remoteGrpcClient is a remote client connecting from a GRPC proxy to a Hub. +type remoteGrpcClient struct { + hub *Hub + client RpcSessions_ProxySessionServer + + sessionId string + remoteAddr string + country string + userAgent string + + closeCtx context.Context + closeFunc context.CancelCauseFunc + + session atomic.Pointer[Session] + messages chan WritableClientMessage +} + +func newRemoteGrpcClient(hub *Hub, request RpcSessions_ProxySessionServer) (*remoteGrpcClient, error) { + md, found := metadata.FromIncomingContext(request.Context()) + if !found { + return nil, errors.New("no metadata provided") + } + + closeCtx, closeFunc := context.WithCancelCause(context.Background()) + + result := &remoteGrpcClient{ + hub: hub, + client: request, + + sessionId: getMD(md, "sessionId"), + remoteAddr: getMD(md, "remoteAddr"), + country: getMD(md, "country"), + userAgent: getMD(md, "userAgent"), + + closeCtx: closeCtx, + closeFunc: closeFunc, + + messages: make(chan WritableClientMessage, grpcRemoteClientMessageQueue), + } + return result, nil +} + +func (c *remoteGrpcClient) readPump() { + var closeError error + defer func() { + c.closeFunc(closeError) + c.hub.OnClosed(c) + }() + + for { + msg, err := c.client.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + // Connection was closed locally. + break + } + + if status.Code(err) != codes.Canceled { + log.Printf("Error reading from remote client for session %s: %s", c.sessionId, err) + closeError = err + } + break + } + + c.hub.OnMessageReceived(c, msg.Message) + } +} + +func (c *remoteGrpcClient) RemoteAddr() string { + return c.remoteAddr +} + +func (c *remoteGrpcClient) UserAgent() string { + return c.userAgent +} + +func (c *remoteGrpcClient) Country() string { + return c.country +} + +func (c *remoteGrpcClient) IsConnected() bool { + return true +} + +func (c *remoteGrpcClient) IsAuthenticated() bool { + return c.GetSession() != nil +} + +func (c *remoteGrpcClient) GetSession() Session { + session := c.session.Load() + if session == nil { + return nil + } + + return *session +} + +func (c *remoteGrpcClient) SetSession(session Session) { + if session == nil { + c.session.Store(nil) + } else { + c.session.Store(&session) + } +} + +func (c *remoteGrpcClient) SendError(e *Error) bool { + message := &ServerMessage{ + Type: "error", + Error: e, + } + return c.SendMessage(message) +} + +func (c *remoteGrpcClient) SendByeResponse(message *ClientMessage) bool { + return c.SendByeResponseWithReason(message, "") +} + +func (c *remoteGrpcClient) SendByeResponseWithReason(message *ClientMessage, reason string) bool { + response := &ServerMessage{ + Type: "bye", + } + if message != nil { + response.Id = message.Id + } + if reason != "" { + if response.Bye == nil { + response.Bye = &ByeServerMessage{} + } + response.Bye.Reason = reason + } + return c.SendMessage(response) +} + +func (c *remoteGrpcClient) SendMessage(message WritableClientMessage) bool { + if c.closeCtx.Err() != nil { + return false + } + + select { + case c.messages <- message: + return true + default: + log.Printf("Message queue for remote client of session %s is full, not sending %+v", c.sessionId, message) + return false + } +} + +func (c *remoteGrpcClient) Close() { + c.closeFunc(nil) +} + +func (c *remoteGrpcClient) run() error { + go c.readPump() + + for { + select { + case <-c.closeCtx.Done(): + if err := context.Cause(c.closeCtx); err != context.Canceled { + return err + } + return nil + case msg := <-c.messages: + data, err := json.Marshal(msg) + if err != nil { + log.Printf("Error marshalling %+v for remote client for session %s: %s", msg, c.sessionId, err) + continue + } + + if err := c.client.Send(&ServerSessionMessage{ + Message: data, + }); err != nil { + return fmt.Errorf("error sending %+v to remote client for session %s: %w", msg, c.sessionId, err) + } + } + } +} diff --git a/grpc_server.go b/grpc_server.go index 6dd01e9..236467d 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -113,6 +113,20 @@ func (s *GrpcServer) Close() { } } +func (s *GrpcServer) LookupResumeId(ctx context.Context, request *LookupResumeIdRequest) (*LookupResumeIdReply, error) { + statsGrpcServerCalls.WithLabelValues("LookupResumeId").Inc() + // TODO: Remove debug logging + log.Printf("Lookup session for resume id %s", request.ResumeId) + session := s.hub.GetSessionByResumeId(request.ResumeId) + if session == nil { + return nil, status.Error(codes.NotFound, "no such room session id") + } + + return &LookupResumeIdReply{ + SessionId: session.PublicId(), + }, nil +} + func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) { statsGrpcServerCalls.WithLabelValues("LookupSessionId").Inc() // TODO: Remove debug logging @@ -216,3 +230,16 @@ func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCou Count: uint32(backend.Len()), }, nil } + +func (s *GrpcServer) ProxySession(request RpcSessions_ProxySessionServer) error { + statsGrpcServerCalls.WithLabelValues("ProxySession").Inc() + client, err := newRemoteGrpcClient(s.hub, request) + if err != nil { + return err + } + + sid := s.hub.registerClient(client) + defer s.hub.unregisterClient(sid) + + return client.run() +} diff --git a/grpc_sessions.proto b/grpc_sessions.proto index 9eef15a..4dbfab4 100644 --- a/grpc_sessions.proto +++ b/grpc_sessions.proto @@ -26,8 +26,18 @@ option go_package = "github.com/strukturag/nextcloud-spreed-signaling;signaling" package signaling; service RpcSessions { + rpc LookupResumeId(LookupResumeIdRequest) returns (LookupResumeIdReply) {} rpc LookupSessionId(LookupSessionIdRequest) returns (LookupSessionIdReply) {} rpc IsSessionInCall(IsSessionInCallRequest) returns (IsSessionInCallReply) {} + rpc ProxySession(stream ClientSessionMessage) returns (stream ServerSessionMessage) {} +} + +message LookupResumeIdRequest { + string resumeId = 1; +} + +message LookupResumeIdReply { + string sessionId = 1; } message LookupSessionIdRequest { @@ -49,3 +59,11 @@ message IsSessionInCallRequest { message IsSessionInCallReply { bool inCall = 1; } + +message ClientSessionMessage { + bytes message = 1; +} + +message ServerSessionMessage { + bytes message = 1; +} diff --git a/hub.go b/hub.go index e92b566..f26bd23 100644 --- a/hub.go +++ b/hub.go @@ -155,6 +155,7 @@ type Hub struct { anonymousSessions map[*ClientSession]time.Time expectHelloClients map[HandlerClient]time.Time dialoutSessions map[*ClientSession]bool + remoteSessions map[*RemoteSession]bool backendTimeout time.Duration backend *BackendClient @@ -343,6 +344,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer anonymousSessions: make(map[*ClientSession]time.Time), expectHelloClients: make(map[HandlerClient]time.Time), dialoutSessions: make(map[*ClientSession]bool), + remoteSessions: make(map[*RemoteSession]bool), backendTimeout: backendTimeout, backend: backend, @@ -584,6 +586,22 @@ func (h *Hub) GetSessionByPublicId(sessionId string) Session { return session } +func (h *Hub) GetSessionByResumeId(resumeId string) Session { + data := h.decodeSessionId(resumeId, privateSessionName) + if data == nil { + return nil + } + + h.mu.RLock() + defer h.mu.RUnlock() + session := h.sessions[data.Sid] + if session != nil && session.PrivateId() != resumeId { + // Session was created on different server. + return nil + } + return session +} + func (h *Hub) GetDialoutSession(roomId string, backend *Backend) *ClientSession { url := backend.Url() @@ -715,6 +733,30 @@ func (h *Hub) sendWelcome(client HandlerClient) { client.SendMessage(h.getWelcomeMessage()) } +func (h *Hub) registerClient(client HandlerClient) uint64 { + sid := h.sid.Add(1) + for sid == 0 { + sid = h.sid.Add(1) + } + + h.mu.Lock() + defer h.mu.Unlock() + h.clients[sid] = client + return sid +} + +func (h *Hub) unregisterClient(sid uint64) { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.clients, sid) +} + +func (h *Hub) unregisterRemoteSession(session *RemoteSession) { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.remoteSessions, session) +} + func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData { sid := h.sid.Add(1) for sid == 0 { @@ -953,12 +995,97 @@ func (h *Hub) sendHelloResponse(session *ClientSession, message *ClientMessage) return session.SendMessage(response) } +type remoteClientInfo struct { + client *GrpcClient + response *LookupResumeIdReply +} + +func (h *Hub) tryProxyResume(c HandlerClient, resumeId string, message *ClientMessage) bool { + client, ok := c.(*Client) + if !ok { + return false + } + + var clients []*GrpcClient + if h.rpcClients != nil { + clients = h.rpcClients.GetClients() + } + if len(clients) == 0 { + return false + } + + rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer rpcCancel() + + var wg sync.WaitGroup + ctx, cancel := context.WithCancel(rpcCtx) + defer cancel() + + var remoteClient atomic.Pointer[remoteClientInfo] + for _, c := range clients { + wg.Add(1) + go func(client *GrpcClient) { + defer wg.Done() + + if client.IsSelf() { + return + } + + response, err := client.LookupResumeId(ctx, resumeId) + if err != nil { + log.Printf("Could not lookup resume id %s on %s: %s", resumeId, client.Target(), err) + return + } + + cancel() + remoteClient.CompareAndSwap(nil, &remoteClientInfo{ + client: client, + response: response, + }) + }(c) + } + wg.Wait() + + if !client.IsConnected() { + // Client disconnected while checking message. + return false + } + + info := remoteClient.Load() + if info == nil { + return false + } + + rs, err := NewRemoteSession(h, client, info.client, info.response.SessionId) + if err != nil { + log.Printf("Could not create remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err) + return false + } + + if err := rs.Start(message); err != nil { + rs.Close() + log.Printf("Could not start remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err) + return false + } + + log.Printf("Proxy session %s to %s", info.response.SessionId, info.client.Target()) + h.mu.Lock() + defer h.mu.Unlock() + h.remoteSessions[rs] = true + delete(h.expectHelloClients, client) + return true +} + func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { resumeId := message.Hello.ResumeId if resumeId != "" { data := h.decodeSessionId(resumeId, privateSessionName) if data == nil { statsHubSessionResumeFailed.Inc() + if h.tryProxyResume(client, resumeId, message) { + return + } + client.SendMessage(message.NewErrorServerMessage(NoSuchSession)) return } @@ -968,6 +1095,10 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { if !found || resumeId != session.PrivateId() { h.mu.Unlock() statsHubSessionResumeFailed.Inc() + if h.tryProxyResume(client, resumeId, message) { + return + } + client.SendMessage(message.NewErrorServerMessage(NoSuchSession)) return } diff --git a/hub_test.go b/hub_test.go index 9321733..8419ac9 100644 --- a/hub_test.go +++ b/hub_test.go @@ -274,13 +274,19 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) { h.mu.Lock() clients := len(h.clients) sessions := len(h.sessions) + remoteSessions := len(h.remoteSessions) h.mu.Unlock() h.ru.Lock() rooms := len(h.rooms) h.ru.Unlock() readActive := h.readPumpActive.Load() writeActive := h.writePumpActive.Load() - if clients == 0 && rooms == 0 && sessions == 0 && readActive == 0 && writeActive == 0 { + if clients == 0 && + rooms == 0 && + sessions == 0 && + remoteSessions == 0 && + readActive == 0 && + writeActive == 0 { break } @@ -289,7 +295,7 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) { h.mu.Lock() h.ru.Lock() dumpGoroutines("", os.Stderr) - t.Errorf("Error waiting for clients %+v / rooms %+v / sessions %+v / %d read / %d write to terminate: %s", h.clients, h.rooms, h.sessions, readActive, writeActive, ctx.Err()) + t.Errorf("Error waiting for clients %+v / rooms %+v / sessions %+v / remoteSessions %v / %d read / %d write to terminate: %s", h.clients, h.rooms, h.sessions, h.remoteSessions, readActive, writeActive, ctx.Err()) h.ru.Unlock() h.mu.Unlock() return @@ -1892,6 +1898,307 @@ func TestClientHelloResumeAndJoin(t *testing.T) { } } +func runGrpcProxyTest(t *testing.T, f func(hub1, hub2 *Hub, server1, server2 *httptest.Server)) { + t.Helper() + + var hub1 *Hub + var hub2 *Hub + var server1 *httptest.Server + var server2 *httptest.Server + var router1 *mux.Router + var router2 *mux.Router + hub1, hub2, router1, router2, server1, server2 = CreateClusteredHubsForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { + // Make sure all backends use the same server + if server1 == nil { + server1 = server + } else { + server = server1 + } + + config, err := getTestConfig(server) + if err != nil { + return nil, err + } + + config.RemoveOption("backend", "allowed") + config.RemoveOption("backend", "secret") + config.AddOption("backend", "backends", "backend1") + + config.AddOption("backend1", "url", server.URL) + config.AddOption("backend1", "secret", string(testBackendSecret)) + config.AddOption("backend1", "sessionlimit", "1") + return config, nil + }) + + registerBackendHandlerUrl(t, router1, "/") + registerBackendHandlerUrl(t, router2, "/") + + f(hub1, hub2, server1, server2) +} + +func TestClientHelloResumeProxy(t *testing.T) { + ensureNoGoroutinesLeak(t, func(t *testing.T) { + runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) { + client1 := NewTestClient(t, server1, hub1) + defer client1.CloseWithBye() + + if err := client1.SendHello(testDefaultUserId); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + hello, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } else { + if hello.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + if hello.Hello.ResumeId == "" { + t.Errorf("Expected resume id, got %+v", hello.Hello) + } + } + + client1.Close() + if err := client1.WaitForClientRemoved(ctx); err != nil { + t.Error(err) + } + + client2 := NewTestClient(t, server2, hub2) + defer client2.CloseWithBye() + + if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil { + t.Fatal(err) + } + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Error(err) + } else { + if hello2.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) + } + if hello2.Hello.SessionId != hello.Hello.SessionId { + t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) + } + if hello2.Hello.ResumeId != hello.Hello.ResumeId { + t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) + } + } + + // Join room by id. + roomId := "test-room" + if room, err := client2.JoinRoom(ctx, roomId); err != nil { + t.Fatal(err) + } else if room.Room.RoomId != roomId { + t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId) + } + + // We will receive a "joined" event. + if err := client2.RunUntilJoined(ctx, hello.Hello); err != nil { + t.Error(err) + } + + if room := hub1.getRoom(roomId); room == nil { + t.Fatalf("Could not find room %s", roomId) + } + if room := hub2.getRoom(roomId); room != nil { + t.Fatalf("Should not have gotten room %s, got %+v", roomId, room) + } + + users := []map[string]interface{}{ + { + "sessionId": "the-session-id", + "inCall": 1, + }, + } + room := hub1.getRoom(roomId) + if room == nil { + t.Fatalf("Could not find room %s", roomId) + } + room.PublishUsersInCallChanged(users, users) + if err := checkReceiveClientEvent(ctx, client2, "update", nil); err != nil { + t.Error(err) + } + }) + }) +} + +func TestClientHelloResumeProxy_Takeover(t *testing.T) { + ensureNoGoroutinesLeak(t, func(t *testing.T) { + runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) { + client1 := NewTestClient(t, server1, hub1) + defer client1.CloseWithBye() + + if err := client1.SendHello(testDefaultUserId); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + hello, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } else { + if hello.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + if hello.Hello.ResumeId == "" { + t.Errorf("Expected resume id, got %+v", hello.Hello) + } + } + + client2 := NewTestClient(t, server2, hub2) + defer client2.CloseWithBye() + + if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil { + t.Fatal(err) + } + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Error(err) + } else { + if hello2.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) + } + if hello2.Hello.SessionId != hello.Hello.SessionId { + t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) + } + if hello2.Hello.ResumeId != hello.Hello.ResumeId { + t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) + } + } + + // The first client got disconnected with a reason in a "Bye" message. + if msg, err := client1.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else { + if msg.Type != "bye" || msg.Bye == nil { + t.Errorf("Expected bye message, got %+v", msg) + } else if msg.Bye.Reason != "session_resumed" { + t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason) + } + } + + if msg, err := client1.RunUntilMessage(ctx); err == nil { + t.Errorf("Expected error but received %+v", msg) + } else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { + t.Errorf("Expected close error but received %+v", err) + } + + client3 := NewTestClient(t, server1, hub1) + defer client3.CloseWithBye() + + if err := client3.SendHelloResume(hello.Hello.ResumeId); err != nil { + t.Fatal(err) + } + hello3, err := client3.RunUntilHello(ctx) + if err != nil { + t.Error(err) + } else { + if hello3.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) + } + if hello3.Hello.SessionId != hello.Hello.SessionId { + t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) + } + if hello3.Hello.ResumeId != hello.Hello.ResumeId { + t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) + } + } + + // The second client got disconnected with a reason in a "Bye" message. + if msg, err := client2.RunUntilMessage(ctx); err != nil { + t.Error(err) + } else { + if msg.Type != "bye" || msg.Bye == nil { + t.Errorf("Expected bye message, got %+v", msg) + } else if msg.Bye.Reason != "session_resumed" { + t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason) + } + } + + if msg, err := client2.RunUntilMessage(ctx); err == nil { + t.Errorf("Expected error but received %+v", msg) + } else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { + t.Errorf("Expected close error but received %+v", err) + } + }) + }) +} + +func TestClientHelloResumeProxy_Disconnect(t *testing.T) { + ensureNoGoroutinesLeak(t, func(t *testing.T) { + runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) { + client1 := NewTestClient(t, server1, hub1) + defer client1.CloseWithBye() + + if err := client1.SendHello(testDefaultUserId); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + hello, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } else { + if hello.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + if hello.Hello.ResumeId == "" { + t.Errorf("Expected resume id, got %+v", hello.Hello) + } + } + + client1.Close() + if err := client1.WaitForClientRemoved(ctx); err != nil { + t.Error(err) + } + + client2 := NewTestClient(t, server2, hub2) + defer client2.CloseWithBye() + + if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil { + t.Fatal(err) + } + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Error(err) + } else { + if hello2.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello) + } + if hello2.Hello.SessionId != hello.Hello.SessionId { + t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello) + } + if hello2.Hello.ResumeId != hello.Hello.ResumeId { + t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello) + } + } + + // Simulate unclean shutdown of second instance. + hub2.rpcServer.conn.Stop() + + if err := client2.WaitForClientRemoved(ctx); err != nil { + t.Error(err) + } + }) + }) +} + func TestClientHelloClient(t *testing.T) { hub, _, _, server := CreateHubForTest(t) diff --git a/remotesession.go b/remotesession.go new file mode 100644 index 0000000..3aefbba --- /dev/null +++ b/remotesession.go @@ -0,0 +1,152 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "sync/atomic" + "time" +) + +type RemoteSession struct { + hub *Hub + client *Client + remoteClient *GrpcClient + sessionId string + + proxy atomic.Pointer[SessionProxy] +} + +func NewRemoteSession(hub *Hub, client *Client, remoteClient *GrpcClient, sessionId string) (*RemoteSession, error) { + remoteSession := &RemoteSession{ + hub: hub, + client: client, + remoteClient: remoteClient, + sessionId: sessionId, + } + + client.SetSessionId(sessionId) + client.SetHandler(remoteSession) + + proxy, err := remoteClient.ProxySession(context.Background(), sessionId, remoteSession) + if err != nil { + return nil, err + } + remoteSession.proxy.Store(proxy) + + return remoteSession, nil +} + +func (s *RemoteSession) Country() string { + return s.client.Country() +} + +func (s *RemoteSession) RemoteAddr() string { + return s.client.RemoteAddr() +} + +func (s *RemoteSession) UserAgent() string { + return s.client.UserAgent() +} + +func (s *RemoteSession) IsConnected() bool { + return true +} + +func (s *RemoteSession) Start(message *ClientMessage) error { + return s.sendMessage(message) +} + +func (s *RemoteSession) OnProxyMessage(msg *ServerSessionMessage) error { + var message *ServerMessage + if err := json.Unmarshal(msg.Message, &message); err != nil { + return err + } + + if !s.client.SendMessage(message) { + return fmt.Errorf("could not send message to client") + } + + return nil +} + +func (s *RemoteSession) OnProxyClose(err error) { + if err != nil { + log.Printf("Proxy connection for session %s to %s was closed with error: %s", s.sessionId, s.remoteClient.Target(), err) + } + s.Close() +} + +func (s *RemoteSession) SendMessage(message WritableClientMessage) bool { + return s.sendMessage(message) == nil +} + +func (s *RemoteSession) sendProxyMessage(message []byte) error { + proxy := s.proxy.Load() + if proxy == nil { + return errors.New("proxy already closed") + } + + msg := &ClientSessionMessage{ + Message: message, + } + return proxy.Send(msg) +} + +func (s *RemoteSession) sendMessage(message interface{}) error { + data, err := json.Marshal(message) + if err != nil { + return err + } + + return s.sendProxyMessage(data) +} + +func (s *RemoteSession) Close() { + if proxy := s.proxy.Swap(nil); proxy != nil { + proxy.Close() + } + s.hub.unregisterRemoteSession(s) + s.client.Close() +} + +func (s *RemoteSession) OnLookupCountry(client HandlerClient) string { + return s.hub.OnLookupCountry(client) +} + +func (s *RemoteSession) OnClosed(client HandlerClient) { + s.Close() +} + +func (s *RemoteSession) OnMessageReceived(client HandlerClient, message []byte) { + if err := s.sendProxyMessage(message); err != nil { + log.Printf("Error sending %s to the proxy for session %s: %s", string(message), s.sessionId, err) + s.Close() + } +} + +func (s *RemoteSession) OnRTTReceived(client HandlerClient, rtt time.Duration) { +}