Merge pull request #715 from strukturag/resume-remote

Support resuming remote sessions
This commit is contained in:
Joachim Bauch 2024-04-23 11:58:07 +02:00 committed by GitHub
commit d368a060fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1188 additions and 134 deletions

View file

@ -387,7 +387,7 @@ type HelloClientMessage struct {
Features []string `json:"features,omitempty"`
// The authentication credentials.
Auth HelloClientMessageAuth `json:"auth"`
Auth *HelloClientMessageAuth `json:"auth,omitempty"`
}
func (m *HelloClientMessage) CheckValid() error {
@ -395,7 +395,7 @@ func (m *HelloClientMessage) CheckValid() error {
return InvalidHelloVersion
}
if m.ResumeId == "" {
if m.Auth.Params == nil || len(*m.Auth.Params) == 0 {
if m.Auth == nil || m.Auth.Params == nil || len(*m.Auth.Params) == 0 {
return fmt.Errorf("params missing")
}
if m.Auth.Type == "" {

View file

@ -95,14 +95,14 @@ func TestHelloClientMessage(t *testing.T) {
// Hello version 1
&HelloClientMessage{
Version: HelloVersionV1,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Params: &json.RawMessage{'{', '}'},
Url: "https://domain.invalid",
},
},
&HelloClientMessage{
Version: HelloVersionV1,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Type: "client",
Params: &json.RawMessage{'{', '}'},
Url: "https://domain.invalid",
@ -110,7 +110,7 @@ func TestHelloClientMessage(t *testing.T) {
},
&HelloClientMessage{
Version: HelloVersionV1,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Type: "internal",
Params: (*json.RawMessage)(&internalAuthParams),
},
@ -122,14 +122,14 @@ func TestHelloClientMessage(t *testing.T) {
// Hello version 2
&HelloClientMessage{
Version: HelloVersionV2,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Params: (*json.RawMessage)(&tokenAuthParams),
Url: "https://domain.invalid",
},
},
&HelloClientMessage{
Version: HelloVersionV2,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Type: "client",
Params: (*json.RawMessage)(&tokenAuthParams),
Url: "https://domain.invalid",
@ -147,40 +147,40 @@ func TestHelloClientMessage(t *testing.T) {
&HelloClientMessage{Version: HelloVersionV1},
&HelloClientMessage{
Version: HelloVersionV1,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Params: &json.RawMessage{'{', '}'},
Type: "invalid-type",
},
},
&HelloClientMessage{
Version: HelloVersionV1,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Url: "https://domain.invalid",
},
},
&HelloClientMessage{
Version: HelloVersionV1,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Params: &json.RawMessage{'{', '}'},
},
},
&HelloClientMessage{
Version: HelloVersionV1,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Params: &json.RawMessage{'{', '}'},
Url: "invalid-url",
},
},
&HelloClientMessage{
Version: HelloVersionV1,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Type: "internal",
Params: &json.RawMessage{'{', '}'},
},
},
&HelloClientMessage{
Version: HelloVersionV1,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Type: "internal",
Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON.
},
@ -188,33 +188,33 @@ func TestHelloClientMessage(t *testing.T) {
// Hello version 2
&HelloClientMessage{
Version: HelloVersionV2,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Url: "https://domain.invalid",
},
},
&HelloClientMessage{
Version: HelloVersionV2,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Params: (*json.RawMessage)(&tokenAuthParams),
},
},
&HelloClientMessage{
Version: HelloVersionV2,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Params: (*json.RawMessage)(&tokenAuthParams),
Url: "invalid-url",
},
},
&HelloClientMessage{
Version: HelloVersionV2,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Params: (*json.RawMessage)(&internalAuthParams),
Url: "https://domain.invalid",
},
},
&HelloClientMessage{
Version: HelloVersionV2,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON.
Url: "https://domain.invalid",
},

107
client.go
View file

@ -92,14 +92,32 @@ type WritableClientMessage interface {
CloseAfterSend(session Session) bool
}
type HandlerClient interface {
RemoteAddr() string
Country() string
UserAgent() string
IsConnected() bool
IsAuthenticated() bool
GetSession() Session
SetSession(session Session)
SendError(e *Error) bool
SendByeResponse(message *ClientMessage) bool
SendByeResponseWithReason(message *ClientMessage, reason string) bool
SendMessage(message WritableClientMessage) bool
Close()
}
type ClientHandler interface {
OnClosed(*Client)
OnMessageReceived(*Client, []byte)
OnRTTReceived(*Client, time.Duration)
OnClosed(HandlerClient)
OnMessageReceived(HandlerClient, []byte)
OnRTTReceived(HandlerClient, time.Duration)
}
type ClientGeoIpHandler interface {
OnLookupCountry(*Client) string
OnLookupCountry(HandlerClient) string
}
type Client struct {
@ -111,7 +129,8 @@ type Client struct {
country *string
logRTT bool
session atomic.Pointer[ClientSession]
session atomic.Pointer[Session]
sessionId atomic.Pointer[string]
mu sync.Mutex
@ -142,12 +161,16 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler
func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler ClientHandler) {
c.conn = conn
c.addr = remoteAddress
c.handler = handler
c.SetHandler(handler)
c.closer = NewCloser()
c.messageChan = make(chan *bytes.Buffer, 16)
c.messagesDone = make(chan struct{})
}
func (c *Client) SetHandler(handler ClientHandler) {
c.handler = handler
}
func (c *Client) IsConnected() bool {
return c.closed.Load() == 0
}
@ -156,12 +179,39 @@ func (c *Client) IsAuthenticated() bool {
return c.GetSession() != nil
}
func (c *Client) GetSession() *ClientSession {
return c.session.Load()
func (c *Client) GetSession() Session {
session := c.session.Load()
if session == nil {
return nil
}
return *session
}
func (c *Client) SetSession(session *ClientSession) {
c.session.Store(session)
func (c *Client) SetSession(session Session) {
if session == nil {
c.session.Store(nil)
} else {
c.session.Store(&session)
}
}
func (c *Client) SetSessionId(sessionId string) {
c.sessionId.Store(&sessionId)
}
func (c *Client) GetSessionId() string {
sessionId := c.sessionId.Load()
if sessionId == nil {
session := c.GetSession()
if session == nil {
return ""
}
return session.PublicId()
}
return *sessionId
}
func (c *Client) RemoteAddr() string {
@ -234,12 +284,14 @@ func (c *Client) SendByeResponse(message *ClientMessage) bool {
func (c *Client) SendByeResponseWithReason(message *ClientMessage, reason string) bool {
response := &ServerMessage{
Type: "bye",
Bye: &ByeServerMessage{},
}
if message != nil {
response.Id = message.Id
}
if reason != "" {
if response.Bye == nil {
response.Bye = &ByeServerMessage{}
}
response.Bye.Reason = reason
}
return c.SendMessage(response)
@ -277,8 +329,8 @@ func (c *Client) ReadPump() {
rtt := now.Sub(time.Unix(0, ts))
if c.logRTT {
rtt_ms := rtt.Nanoseconds() / time.Millisecond.Nanoseconds()
if session := c.GetSession(); session != nil {
log.Printf("Client %s has RTT of %d ms (%s)", session.PublicId(), rtt_ms, rtt)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Client %s has RTT of %d ms (%s)", sessionId, rtt_ms, rtt)
} else {
log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt)
}
@ -296,8 +348,8 @@ func (c *Client) ReadPump() {
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived) {
if session := c.GetSession(); session != nil {
log.Printf("Error reading from client %s: %v", session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Error reading from client %s: %v", sessionId, err)
} else {
log.Printf("Error reading from %s: %v", addr, err)
}
@ -306,8 +358,8 @@ func (c *Client) ReadPump() {
}
if messageType != websocket.TextMessage {
if session := c.GetSession(); session != nil {
log.Printf("Unsupported message type %v from client %s", messageType, session.PublicId())
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Unsupported message type %v from client %s", messageType, sessionId)
} else {
log.Printf("Unsupported message type %v from %s", messageType, addr)
}
@ -319,8 +371,8 @@ func (c *Client) ReadPump() {
decodeBuffer.Reset()
if _, err := decodeBuffer.ReadFrom(reader); err != nil {
bufferPool.Put(decodeBuffer)
if session := c.GetSession(); session != nil {
log.Printf("Error reading message from client %s: %v", session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Error reading message from client %s: %v", sessionId, err)
} else {
log.Printf("Error reading message from %s: %v", addr, err)
}
@ -373,8 +425,8 @@ func (c *Client) writeInternal(message json.Marshaler) bool {
return false
}
if session := c.GetSession(); session != nil {
log.Printf("Could not send message %+v to client %s: %v", message, session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Could not send message %+v to client %s: %v", message, sessionId, err)
} else {
log.Printf("Could not send message %+v to %s: %v", message, c.RemoteAddr(), err)
}
@ -386,8 +438,8 @@ func (c *Client) writeInternal(message json.Marshaler) bool {
close:
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint
if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil {
if session := c.GetSession(); session != nil {
log.Printf("Could not send close message to client %s: %v", session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Could not send close message to client %s: %v", sessionId, err)
} else {
log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err)
}
@ -413,8 +465,8 @@ func (c *Client) writeError(e error) bool { // nolint
closeData := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, e.Error())
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint
if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil {
if session := c.GetSession(); session != nil {
log.Printf("Could not send close message to client %s: %v", session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Could not send close message to client %s: %v", sessionId, err)
} else {
log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err)
}
@ -445,7 +497,6 @@ func (c *Client) writeMessageLocked(message WritableClientMessage) bool {
go session.Close()
}
go c.Close()
return false
}
return true
@ -462,8 +513,8 @@ func (c *Client) sendPing() bool {
msg := strconv.FormatInt(now, 10)
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint
if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil {
if session := c.GetSession(); session != nil {
log.Printf("Could not send ping to client %s: %v", session.PublicId(), err)
if sessionId := c.GetSessionId(); sessionId != "" {
log.Printf("Could not send ping to client %s: %v", sessionId, err)
} else {
log.Printf("Could not send ping to %s: %v", c.RemoteAddr(), err)
}

View file

@ -601,7 +601,7 @@ func main() {
Type: "hello",
Hello: &signaling.HelloClientMessage{
Version: signaling.HelloVersionV1,
Auth: signaling.HelloClientMessageAuth{
Auth: &signaling.HelloClientMessageAuth{
Url: backendUrl + "/auth",
Params: &json.RawMessage{'{', '}'},
},

View file

@ -67,7 +67,7 @@ type ClientSession struct {
mu sync.Mutex
client *Client
client HandlerClient
room atomic.Pointer[Room]
roomJoinTime atomic.Int64
roomSessionId string
@ -500,14 +500,14 @@ func (s *ClientSession) doUnsubscribeRoomEvents(notify bool) {
s.roomSessionId = ""
}
func (s *ClientSession) ClearClient(client *Client) {
func (s *ClientSession) ClearClient(client HandlerClient) {
s.mu.Lock()
defer s.mu.Unlock()
s.clearClientLocked(client)
}
func (s *ClientSession) clearClientLocked(client *Client) {
func (s *ClientSession) clearClientLocked(client HandlerClient) {
if s.client == nil {
return
} else if client != nil && s.client != client {
@ -520,18 +520,18 @@ func (s *ClientSession) clearClientLocked(client *Client) {
prevClient.SetSession(nil)
}
func (s *ClientSession) GetClient() *Client {
func (s *ClientSession) GetClient() HandlerClient {
s.mu.Lock()
defer s.mu.Unlock()
return s.getClientUnlocked()
}
func (s *ClientSession) getClientUnlocked() *Client {
func (s *ClientSession) getClientUnlocked() HandlerClient {
return s.client
}
func (s *ClientSession) SetClient(client *Client) *Client {
func (s *ClientSession) SetClient(client HandlerClient) HandlerClient {
if client == nil {
panic("Use ClearClient to set the client to nil")
}
@ -1341,7 +1341,7 @@ func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage {
}
}
func (s *ClientSession) NotifySessionResumed(client *Client) {
func (s *ClientSession) NotifySessionResumed(client HandlerClient) {
s.mu.Lock()
if len(s.pendingClientMessages) == 0 {
s.mu.Unlock()

View file

@ -26,6 +26,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/url"
@ -39,6 +40,7 @@ import (
"google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
status "google.golang.org/grpc/status"
)
@ -51,6 +53,8 @@ const (
)
var (
ErrNoSuchResumeId = fmt.Errorf("unknown resume id")
customResolverPrefix atomic.Uint64
)
@ -185,6 +189,26 @@ func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) {
return response.GetServerId(), nil
}
func (c *GrpcClient) LookupResumeId(ctx context.Context, resumeId string) (*LookupResumeIdReply, error) {
statsGrpcClientCalls.WithLabelValues("LookupResumeId").Inc()
// TODO: Remove debug logging
log.Printf("Lookup resume id %s on %s", resumeId, c.Target())
response, err := c.impl.LookupResumeId(ctx, &LookupResumeIdRequest{
ResumeId: resumeId,
}, grpc.WaitForReady(true))
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
return nil, ErrNoSuchResumeId
} else if err != nil {
return nil, err
}
if sessionId := response.GetSessionId(); sessionId == "" {
return nil, ErrNoSuchResumeId
}
return response, nil
}
func (c *GrpcClient) LookupSessionId(ctx context.Context, roomSessionId string, disconnectReason string) (string, error) {
statsGrpcClientCalls.WithLabelValues("LookupSessionId").Inc()
// TODO: Remove debug logging
@ -258,6 +282,86 @@ func (c *GrpcClient) GetSessionCount(ctx context.Context, u *url.URL) (uint32, e
return response.GetCount(), nil
}
type ProxySessionReceiver interface {
RemoteAddr() string
Country() string
UserAgent() string
OnProxyMessage(message *ServerSessionMessage) error
OnProxyClose(err error)
}
type SessionProxy struct {
sessionId string
receiver ProxySessionReceiver
sendMu sync.Mutex
client RpcSessions_ProxySessionClient
}
func (p *SessionProxy) recvPump() {
var closeError error
defer func() {
p.receiver.OnProxyClose(closeError)
if err := p.Close(); err != nil {
log.Printf("Error closing proxy for session %s: %s", p.sessionId, err)
}
}()
for {
msg, err := p.client.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
log.Printf("Error receiving message from proxy for session %s: %s", p.sessionId, err)
closeError = err
break
}
if err := p.receiver.OnProxyMessage(msg); err != nil {
log.Printf("Error processing message %+v from proxy for session %s: %s", msg, p.sessionId, err)
}
}
}
func (p *SessionProxy) Send(message *ClientSessionMessage) error {
p.sendMu.Lock()
defer p.sendMu.Unlock()
return p.client.Send(message)
}
func (p *SessionProxy) Close() error {
p.sendMu.Lock()
defer p.sendMu.Unlock()
return p.client.CloseSend()
}
func (c *GrpcClient) ProxySession(ctx context.Context, sessionId string, receiver ProxySessionReceiver) (*SessionProxy, error) {
statsGrpcClientCalls.WithLabelValues("ProxySession").Inc()
md := metadata.Pairs(
"sessionId", sessionId,
"remoteAddr", receiver.RemoteAddr(),
"country", receiver.Country(),
"userAgent", receiver.UserAgent(),
)
client, err := c.impl.ProxySession(metadata.NewOutgoingContext(ctx, md), grpc.WaitForReady(true))
if err != nil {
return nil, err
}
proxy := &SessionProxy{
sessionId: sessionId,
receiver: receiver,
client: client,
}
go proxy.recvPump()
return proxy, nil
}
type grpcClientsList struct {
clients []*GrpcClient
entry *DnsMonitorEntry

225
grpc_remote_client.go Normal file
View file

@ -0,0 +1,225 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2024 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"sync/atomic"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
const (
grpcRemoteClientMessageQueue = 16
)
func getMD(md metadata.MD, key string) string {
if values := md.Get(key); len(values) > 0 {
return values[0]
}
return ""
}
// remoteGrpcClient is a remote client connecting from a GRPC proxy to a Hub.
type remoteGrpcClient struct {
hub *Hub
client RpcSessions_ProxySessionServer
sessionId string
remoteAddr string
country string
userAgent string
closeCtx context.Context
closeFunc context.CancelCauseFunc
session atomic.Pointer[Session]
messages chan WritableClientMessage
}
func newRemoteGrpcClient(hub *Hub, request RpcSessions_ProxySessionServer) (*remoteGrpcClient, error) {
md, found := metadata.FromIncomingContext(request.Context())
if !found {
return nil, errors.New("no metadata provided")
}
closeCtx, closeFunc := context.WithCancelCause(context.Background())
result := &remoteGrpcClient{
hub: hub,
client: request,
sessionId: getMD(md, "sessionId"),
remoteAddr: getMD(md, "remoteAddr"),
country: getMD(md, "country"),
userAgent: getMD(md, "userAgent"),
closeCtx: closeCtx,
closeFunc: closeFunc,
messages: make(chan WritableClientMessage, grpcRemoteClientMessageQueue),
}
return result, nil
}
func (c *remoteGrpcClient) readPump() {
var closeError error
defer func() {
c.closeFunc(closeError)
c.hub.OnClosed(c)
}()
for {
msg, err := c.client.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
// Connection was closed locally.
break
}
if status.Code(err) != codes.Canceled {
log.Printf("Error reading from remote client for session %s: %s", c.sessionId, err)
closeError = err
}
break
}
c.hub.OnMessageReceived(c, msg.Message)
}
}
func (c *remoteGrpcClient) RemoteAddr() string {
return c.remoteAddr
}
func (c *remoteGrpcClient) UserAgent() string {
return c.userAgent
}
func (c *remoteGrpcClient) Country() string {
return c.country
}
func (c *remoteGrpcClient) IsConnected() bool {
return true
}
func (c *remoteGrpcClient) IsAuthenticated() bool {
return c.GetSession() != nil
}
func (c *remoteGrpcClient) GetSession() Session {
session := c.session.Load()
if session == nil {
return nil
}
return *session
}
func (c *remoteGrpcClient) SetSession(session Session) {
if session == nil {
c.session.Store(nil)
} else {
c.session.Store(&session)
}
}
func (c *remoteGrpcClient) SendError(e *Error) bool {
message := &ServerMessage{
Type: "error",
Error: e,
}
return c.SendMessage(message)
}
func (c *remoteGrpcClient) SendByeResponse(message *ClientMessage) bool {
return c.SendByeResponseWithReason(message, "")
}
func (c *remoteGrpcClient) SendByeResponseWithReason(message *ClientMessage, reason string) bool {
response := &ServerMessage{
Type: "bye",
}
if message != nil {
response.Id = message.Id
}
if reason != "" {
if response.Bye == nil {
response.Bye = &ByeServerMessage{}
}
response.Bye.Reason = reason
}
return c.SendMessage(response)
}
func (c *remoteGrpcClient) SendMessage(message WritableClientMessage) bool {
if c.closeCtx.Err() != nil {
return false
}
select {
case c.messages <- message:
return true
default:
log.Printf("Message queue for remote client of session %s is full, not sending %+v", c.sessionId, message)
return false
}
}
func (c *remoteGrpcClient) Close() {
c.closeFunc(nil)
}
func (c *remoteGrpcClient) run() error {
go c.readPump()
for {
select {
case <-c.closeCtx.Done():
if err := context.Cause(c.closeCtx); err != context.Canceled {
return err
}
return nil
case msg := <-c.messages:
data, err := json.Marshal(msg)
if err != nil {
log.Printf("Error marshalling %+v for remote client for session %s: %s", msg, c.sessionId, err)
continue
}
if err := c.client.Send(&ServerSessionMessage{
Message: data,
}); err != nil {
return fmt.Errorf("error sending %+v to remote client for session %s: %w", msg, c.sessionId, err)
}
}
}
}

View file

@ -113,6 +113,20 @@ func (s *GrpcServer) Close() {
}
}
func (s *GrpcServer) LookupResumeId(ctx context.Context, request *LookupResumeIdRequest) (*LookupResumeIdReply, error) {
statsGrpcServerCalls.WithLabelValues("LookupResumeId").Inc()
// TODO: Remove debug logging
log.Printf("Lookup session for resume id %s", request.ResumeId)
session := s.hub.GetSessionByResumeId(request.ResumeId)
if session == nil {
return nil, status.Error(codes.NotFound, "no such room session id")
}
return &LookupResumeIdReply{
SessionId: session.PublicId(),
}, nil
}
func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) {
statsGrpcServerCalls.WithLabelValues("LookupSessionId").Inc()
// TODO: Remove debug logging
@ -216,3 +230,16 @@ func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCou
Count: uint32(backend.Len()),
}, nil
}
func (s *GrpcServer) ProxySession(request RpcSessions_ProxySessionServer) error {
statsGrpcServerCalls.WithLabelValues("ProxySession").Inc()
client, err := newRemoteGrpcClient(s.hub, request)
if err != nil {
return err
}
sid := s.hub.registerClient(client)
defer s.hub.unregisterClient(sid)
return client.run()
}

View file

@ -26,8 +26,18 @@ option go_package = "github.com/strukturag/nextcloud-spreed-signaling;signaling"
package signaling;
service RpcSessions {
rpc LookupResumeId(LookupResumeIdRequest) returns (LookupResumeIdReply) {}
rpc LookupSessionId(LookupSessionIdRequest) returns (LookupSessionIdReply) {}
rpc IsSessionInCall(IsSessionInCallRequest) returns (IsSessionInCallReply) {}
rpc ProxySession(stream ClientSessionMessage) returns (stream ServerSessionMessage) {}
}
message LookupResumeIdRequest {
string resumeId = 1;
}
message LookupResumeIdReply {
string sessionId = 1;
}
message LookupSessionIdRequest {
@ -49,3 +59,11 @@ message IsSessionInCallRequest {
message IsSessionInCallReply {
bool inCall = 1;
}
message ClientSessionMessage {
bytes message = 1;
}
message ServerSessionMessage {
bytes message = 1;
}

279
hub.go
View file

@ -135,7 +135,7 @@ type Hub struct {
ru sync.RWMutex
sid atomic.Uint64
clients map[uint64]*Client
clients map[uint64]HandlerClient
sessions map[uint64]Session
rooms map[string]*Room
@ -153,8 +153,9 @@ type Hub struct {
expiredSessions map[Session]time.Time
anonymousSessions map[*ClientSession]time.Time
expectHelloClients map[*Client]time.Time
expectHelloClients map[HandlerClient]time.Time
dialoutSessions map[*ClientSession]bool
remoteSessions map[*RemoteSession]bool
backendTimeout time.Duration
backend *BackendClient
@ -324,7 +325,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
roomInCall: make(chan *BackendServerRoomRequest),
roomParticipants: make(chan *BackendServerRoomRequest),
clients: make(map[uint64]*Client),
clients: make(map[uint64]HandlerClient),
sessions: make(map[uint64]Session),
rooms: make(map[string]*Room),
@ -341,8 +342,9 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
expiredSessions: make(map[Session]time.Time),
anonymousSessions: make(map[*ClientSession]time.Time),
expectHelloClients: make(map[*Client]time.Time),
expectHelloClients: make(map[HandlerClient]time.Time),
dialoutSessions: make(map[*ClientSession]bool),
remoteSessions: make(map[*RemoteSession]bool),
backendTimeout: backendTimeout,
backend: backend,
@ -584,6 +586,22 @@ func (h *Hub) GetSessionByPublicId(sessionId string) Session {
return session
}
func (h *Hub) GetSessionByResumeId(resumeId string) Session {
data := h.decodeSessionId(resumeId, privateSessionName)
if data == nil {
return nil
}
h.mu.RLock()
defer h.mu.RUnlock()
session := h.sessions[data.Sid]
if session != nil && session.PrivateId() != resumeId {
// Session was created on different server.
return nil
}
return session
}
func (h *Hub) GetDialoutSession(roomId string, backend *Backend) *ClientSession {
url := backend.Url()
@ -690,15 +708,13 @@ func (h *Hub) startWaitAnonymousSessionRoomLocked(session *ClientSession) {
h.anonymousSessions[session] = now.Add(anonmyousJoinRoomTimeout)
}
func (h *Hub) startExpectHello(client *Client) {
func (h *Hub) startExpectHello(client HandlerClient) {
h.mu.Lock()
defer h.mu.Unlock()
if !client.IsConnected() {
return
}
client.mu.Lock()
defer client.mu.Unlock()
if client.IsAuthenticated() {
return
}
@ -708,15 +724,39 @@ func (h *Hub) startExpectHello(client *Client) {
h.expectHelloClients[client] = now.Add(initialHelloTimeout)
}
func (h *Hub) processNewClient(client *Client) {
func (h *Hub) processNewClient(client HandlerClient) {
h.startExpectHello(client)
h.sendWelcome(client)
}
func (h *Hub) sendWelcome(client *Client) {
func (h *Hub) sendWelcome(client HandlerClient) {
client.SendMessage(h.getWelcomeMessage())
}
func (h *Hub) registerClient(client HandlerClient) uint64 {
sid := h.sid.Add(1)
for sid == 0 {
sid = h.sid.Add(1)
}
h.mu.Lock()
defer h.mu.Unlock()
h.clients[sid] = client
return sid
}
func (h *Hub) unregisterClient(sid uint64) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.clients, sid)
}
func (h *Hub) unregisterRemoteSession(session *RemoteSession) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.remoteSessions, session)
}
func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData {
sid := h.sid.Add(1)
for sid == 0 {
@ -730,17 +770,24 @@ func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData {
return sessionIdData
}
func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *Backend, auth *BackendClientResponse) {
if !client.IsConnected() {
func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend *Backend, auth *BackendClientResponse) {
if !c.IsConnected() {
// Client disconnected while waiting for "hello" response.
return
}
if auth.Type == "error" {
client.SendMessage(message.NewErrorServerMessage(auth.Error))
c.SendMessage(message.NewErrorServerMessage(auth.Error))
return
} else if auth.Type != "auth" {
client.SendMessage(message.NewErrorServerMessage(UserAuthFailed))
c.SendMessage(message.NewErrorServerMessage(UserAuthFailed))
return
}
client, ok := c.(*Client)
if !ok {
log.Printf("Can't register non-client %T", c)
client.SendMessage(message.NewWrappedErrorServerMessage(errors.New("can't register non-client")))
return
}
@ -844,7 +891,7 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
h.sendHelloResponse(session, message)
}
func (h *Hub) processUnregister(client *Client) *ClientSession {
func (h *Hub) processUnregister(client HandlerClient) Session {
session := client.GetSession()
h.mu.Lock()
@ -857,14 +904,18 @@ func (h *Hub) processUnregister(client *Client) *ClientSession {
h.mu.Unlock()
if session != nil {
log.Printf("Unregister %s (private=%s)", session.PublicId(), session.PrivateId())
session.ClearClient(client)
if c, ok := client.(*Client); ok {
if cs, ok := session.(*ClientSession); ok {
cs.ClearClient(c)
}
}
}
client.Close()
return session
}
func (h *Hub) processMessage(client *Client, data []byte) {
func (h *Hub) processMessage(client HandlerClient, data []byte) {
var message ClientMessage
if err := message.UnmarshalJSON(data); err != nil {
if session := client.GetSession(); session != nil {
@ -944,12 +995,97 @@ func (h *Hub) sendHelloResponse(session *ClientSession, message *ClientMessage)
return session.SendMessage(response)
}
func (h *Hub) processHello(client *Client, message *ClientMessage) {
type remoteClientInfo struct {
client *GrpcClient
response *LookupResumeIdReply
}
func (h *Hub) tryProxyResume(c HandlerClient, resumeId string, message *ClientMessage) bool {
client, ok := c.(*Client)
if !ok {
return false
}
var clients []*GrpcClient
if h.rpcClients != nil {
clients = h.rpcClients.GetClients()
}
if len(clients) == 0 {
return false
}
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer rpcCancel()
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(rpcCtx)
defer cancel()
var remoteClient atomic.Pointer[remoteClientInfo]
for _, c := range clients {
wg.Add(1)
go func(client *GrpcClient) {
defer wg.Done()
if client.IsSelf() {
return
}
response, err := client.LookupResumeId(ctx, resumeId)
if err != nil {
log.Printf("Could not lookup resume id %s on %s: %s", resumeId, client.Target(), err)
return
}
cancel()
remoteClient.CompareAndSwap(nil, &remoteClientInfo{
client: client,
response: response,
})
}(c)
}
wg.Wait()
if !client.IsConnected() {
// Client disconnected while checking message.
return false
}
info := remoteClient.Load()
if info == nil {
return false
}
rs, err := NewRemoteSession(h, client, info.client, info.response.SessionId)
if err != nil {
log.Printf("Could not create remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err)
return false
}
if err := rs.Start(message); err != nil {
rs.Close()
log.Printf("Could not start remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err)
return false
}
log.Printf("Proxy session %s to %s", info.response.SessionId, info.client.Target())
h.mu.Lock()
defer h.mu.Unlock()
h.remoteSessions[rs] = true
delete(h.expectHelloClients, client)
return true
}
func (h *Hub) processHello(client HandlerClient, message *ClientMessage) {
resumeId := message.Hello.ResumeId
if resumeId != "" {
data := h.decodeSessionId(resumeId, privateSessionName)
if data == nil {
statsHubSessionResumeFailed.Inc()
if h.tryProxyResume(client, resumeId, message) {
return
}
client.SendMessage(message.NewErrorServerMessage(NoSuchSession))
return
}
@ -959,6 +1095,10 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) {
if !found || resumeId != session.PrivateId() {
h.mu.Unlock()
statsHubSessionResumeFailed.Inc()
if h.tryProxyResume(client, resumeId, message) {
return
}
client.SendMessage(message.NewErrorServerMessage(NoSuchSession))
return
}
@ -1013,7 +1153,7 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) {
}
}
func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
url := message.Hello.Auth.parsedUrl
backend := h.backend.GetBackend(url)
if backend == nil {
@ -1035,7 +1175,7 @@ func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend,
return backend, &auth, nil
}
func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
func (h *Hub) processHelloV2(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
url := message.Hello.Auth.parsedUrl
backend := h.backend.GetBackend(url)
if backend == nil {
@ -1141,11 +1281,11 @@ func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend,
return backend, auth, nil
}
func (h *Hub) processHelloClient(client *Client, message *ClientMessage) {
func (h *Hub) processHelloClient(client HandlerClient, message *ClientMessage) {
// Make sure the client must send another "hello" in case of errors.
defer h.startExpectHello(client)
var authFunc func(*Client, *ClientMessage) (*Backend, *BackendClientResponse, error)
var authFunc func(HandlerClient, *ClientMessage) (*Backend, *BackendClientResponse, error)
switch message.Hello.Version {
case HelloVersionV1:
// Auth information contains a ticket that must be validated against the
@ -1172,7 +1312,7 @@ func (h *Hub) processHelloClient(client *Client, message *ClientMessage) {
h.processRegister(client, message, backend, auth)
}
func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) {
func (h *Hub) processHelloInternal(client HandlerClient, message *ClientMessage) {
defer h.startExpectHello(client)
if len(h.internalClientsSecret) == 0 {
client.SendMessage(message.NewErrorServerMessage(InvalidClientType))
@ -1261,8 +1401,12 @@ func (h *Hub) sendRoom(session *ClientSession, message *ClientMessage, room *Roo
return session.SendMessage(response)
}
func (h *Hub) processRoom(client *Client, message *ClientMessage) {
session := client.GetSession()
func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
session, ok := client.GetSession().(*ClientSession)
if !ok {
return
}
roomId := message.Room.RoomId
if roomId == "" {
if session == nil {
@ -1281,29 +1425,34 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) {
return
}
if session != nil {
if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) {
// Session already is in that room, no action needed.
roomSessionId := message.Room.SessionId
if roomSessionId == "" {
// TODO(jojo): Better make the session id required in the request.
log.Printf("User did not send a room session id, assuming session %s", session.PublicId())
roomSessionId = session.PublicId()
}
if session == nil {
session.SendMessage(message.NewErrorServerMessage(
NewError("not_authenticated", "Need to authenticate before joining rooms."),
))
return
}
if err := session.UpdateRoomSessionId(roomSessionId); err != nil {
log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err)
}
session.SendMessage(message.NewErrorServerMessage(
NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{
Room: &RoomServerMessage{
RoomId: room.id,
Properties: room.properties,
},
}),
))
return
if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) {
// Session already is in that room, no action needed.
roomSessionId := message.Room.SessionId
if roomSessionId == "" {
// TODO(jojo): Better make the session id required in the request.
log.Printf("User did not send a room session id, assuming session %s", session.PublicId())
roomSessionId = session.PublicId()
}
if err := session.UpdateRoomSessionId(roomSessionId); err != nil {
log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err)
}
session.SendMessage(message.NewErrorServerMessage(
NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{
Room: &RoomServerMessage{
RoomId: room.id,
Properties: room.properties,
},
}),
))
return
}
var room BackendClientResponse
@ -1430,14 +1579,14 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro
r.AddSession(session, room.Room.Session)
}
func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
msg := message.Message
session := client.GetSession()
if session == nil {
func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
session, ok := client.GetSession().(*ClientSession)
if session == nil || !ok {
// Client is not connected yet.
return
}
msg := message.Message
var recipient *ClientSession
var subject string
var clientData *MessageClientMessageData
@ -1484,10 +1633,10 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
// User is stopping to share his screen. Firefox doesn't properly clean
// up the peer connections in all cases, so make sure to stop publishing
// in the MCU.
go func(c *Client) {
go func(c HandlerClient) {
time.Sleep(cleanupScreenPublisherDelay)
session := c.GetSession()
if session == nil {
session, ok := c.GetSession().(*ClientSession)
if session == nil || !ok {
return
}
@ -1700,7 +1849,7 @@ func isAllowedToControl(session Session) bool {
return false
}
func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
func (h *Hub) processControlMsg(client HandlerClient, message *ClientMessage) {
msg := message.Control
session := client.GetSession()
if session == nil {
@ -1813,10 +1962,10 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
}
}
func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) {
func (h *Hub) processInternalMsg(client HandlerClient, message *ClientMessage) {
msg := message.Internal
session := client.GetSession()
if session == nil {
session, ok := client.GetSession().(*ClientSession)
if session == nil || !ok {
// Client is not connected yet.
return
} else if session.ClientType() != HelloClientTypeInternal {
@ -2030,7 +2179,7 @@ func isAllowedToUpdateTransientData(session Session) bool {
return false
}
func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) {
func (h *Hub) processTransientMsg(client HandlerClient, message *ClientMessage) {
msg := message.TransientData
session := client.GetSession()
if session == nil {
@ -2070,17 +2219,17 @@ func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) {
}
}
func sendNotAllowed(session *ClientSession, message *ClientMessage, reason string) {
func sendNotAllowed(session Session, message *ClientMessage, reason string) {
response := message.NewErrorServerMessage(NewError("not_allowed", reason))
session.SendMessage(response)
}
func sendMcuClientNotFound(session *ClientSession, message *ClientMessage) {
func sendMcuClientNotFound(session Session, message *ClientMessage) {
response := message.NewErrorServerMessage(NewError("client_not_found", "No MCU client found to send message to."))
session.SendMessage(response)
}
func sendMcuProcessingFailed(session *ClientSession, message *ClientMessage) {
func sendMcuProcessingFailed(session Session, message *ClientMessage) {
response := message.NewErrorServerMessage(NewError("processing_failed", "Processing of the message failed, please check server logs."))
session.SendMessage(response)
}
@ -2295,7 +2444,7 @@ func (h *Hub) sendMcuMessageResponse(session *ClientSession, mcuClient McuClient
session.SendMessage(response_message)
}
func (h *Hub) processByeMsg(client *Client, message *ClientMessage) {
func (h *Hub) processByeMsg(client HandlerClient, message *ClientMessage) {
client.SendByeResponse(message)
if session := h.processUnregister(client); session != nil {
session.Close()
@ -2412,7 +2561,7 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
}(h)
}
func (h *Hub) OnLookupCountry(client *Client) string {
func (h *Hub) OnLookupCountry(client HandlerClient) string {
ip := net.ParseIP(client.RemoteAddr())
if ip == nil {
return noCountry
@ -2444,14 +2593,14 @@ func (h *Hub) OnLookupCountry(client *Client) string {
return country
}
func (h *Hub) OnClosed(client *Client) {
func (h *Hub) OnClosed(client HandlerClient) {
h.processUnregister(client)
}
func (h *Hub) OnMessageReceived(client *Client, data []byte) {
func (h *Hub) OnMessageReceived(client HandlerClient, data []byte) {
h.processMessage(client, data)
}
func (h *Hub) OnRTTReceived(client *Client, rtt time.Duration) {
func (h *Hub) OnRTTReceived(client HandlerClient, rtt time.Duration) {
// Ignore
}

View file

@ -274,13 +274,19 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
h.mu.Lock()
clients := len(h.clients)
sessions := len(h.sessions)
remoteSessions := len(h.remoteSessions)
h.mu.Unlock()
h.ru.Lock()
rooms := len(h.rooms)
h.ru.Unlock()
readActive := h.readPumpActive.Load()
writeActive := h.writePumpActive.Load()
if clients == 0 && rooms == 0 && sessions == 0 && readActive == 0 && writeActive == 0 {
if clients == 0 &&
rooms == 0 &&
sessions == 0 &&
remoteSessions == 0 &&
readActive == 0 &&
writeActive == 0 {
break
}
@ -289,7 +295,7 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
h.mu.Lock()
h.ru.Lock()
dumpGoroutines("", os.Stderr)
t.Errorf("Error waiting for clients %+v / rooms %+v / sessions %+v / %d read / %d write to terminate: %s", h.clients, h.rooms, h.sessions, readActive, writeActive, ctx.Err())
t.Errorf("Error waiting for clients %+v / rooms %+v / sessions %+v / remoteSessions %v / %d read / %d write to terminate: %s", h.clients, h.rooms, h.sessions, h.remoteSessions, readActive, writeActive, ctx.Err())
h.ru.Unlock()
h.mu.Unlock()
return
@ -1892,6 +1898,307 @@ func TestClientHelloResumeAndJoin(t *testing.T) {
}
}
func runGrpcProxyTest(t *testing.T, f func(hub1, hub2 *Hub, server1, server2 *httptest.Server)) {
t.Helper()
var hub1 *Hub
var hub2 *Hub
var server1 *httptest.Server
var server2 *httptest.Server
var router1 *mux.Router
var router2 *mux.Router
hub1, hub2, router1, router2, server1, server2 = CreateClusteredHubsForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
// Make sure all backends use the same server
if server1 == nil {
server1 = server
} else {
server = server1
}
config, err := getTestConfig(server)
if err != nil {
return nil, err
}
config.RemoveOption("backend", "allowed")
config.RemoveOption("backend", "secret")
config.AddOption("backend", "backends", "backend1")
config.AddOption("backend1", "url", server.URL)
config.AddOption("backend1", "secret", string(testBackendSecret))
config.AddOption("backend1", "sessionlimit", "1")
return config, nil
})
registerBackendHandlerUrl(t, router1, "/")
registerBackendHandlerUrl(t, router2, "/")
f(hub1, hub2, server1, server2)
}
func TestClientHelloResumeProxy(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) {
client1 := NewTestClient(t, server1, hub1)
defer client1.CloseWithBye()
if err := client1.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello, err := client1.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
} else {
if hello.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
if hello.Hello.ResumeId == "" {
t.Errorf("Expected resume id, got %+v", hello.Hello)
}
}
client1.Close()
if err := client1.WaitForClientRemoved(ctx); err != nil {
t.Error(err)
}
client2 := NewTestClient(t, server2, hub2)
defer client2.CloseWithBye()
if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil {
t.Fatal(err)
}
hello2, err := client2.RunUntilHello(ctx)
if err != nil {
t.Error(err)
} else {
if hello2.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
}
if hello2.Hello.SessionId != hello.Hello.SessionId {
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
}
if hello2.Hello.ResumeId != hello.Hello.ResumeId {
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
}
}
// Join room by id.
roomId := "test-room"
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
// We will receive a "joined" event.
if err := client2.RunUntilJoined(ctx, hello.Hello); err != nil {
t.Error(err)
}
if room := hub1.getRoom(roomId); room == nil {
t.Fatalf("Could not find room %s", roomId)
}
if room := hub2.getRoom(roomId); room != nil {
t.Fatalf("Should not have gotten room %s, got %+v", roomId, room)
}
users := []map[string]interface{}{
{
"sessionId": "the-session-id",
"inCall": 1,
},
}
room := hub1.getRoom(roomId)
if room == nil {
t.Fatalf("Could not find room %s", roomId)
}
room.PublishUsersInCallChanged(users, users)
if err := checkReceiveClientEvent(ctx, client2, "update", nil); err != nil {
t.Error(err)
}
})
})
}
func TestClientHelloResumeProxy_Takeover(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) {
client1 := NewTestClient(t, server1, hub1)
defer client1.CloseWithBye()
if err := client1.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello, err := client1.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
} else {
if hello.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
if hello.Hello.ResumeId == "" {
t.Errorf("Expected resume id, got %+v", hello.Hello)
}
}
client2 := NewTestClient(t, server2, hub2)
defer client2.CloseWithBye()
if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil {
t.Fatal(err)
}
hello2, err := client2.RunUntilHello(ctx)
if err != nil {
t.Error(err)
} else {
if hello2.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
}
if hello2.Hello.SessionId != hello.Hello.SessionId {
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
}
if hello2.Hello.ResumeId != hello.Hello.ResumeId {
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
}
}
// The first client got disconnected with a reason in a "Bye" message.
if msg, err := client1.RunUntilMessage(ctx); err != nil {
t.Error(err)
} else {
if msg.Type != "bye" || msg.Bye == nil {
t.Errorf("Expected bye message, got %+v", msg)
} else if msg.Bye.Reason != "session_resumed" {
t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason)
}
}
if msg, err := client1.RunUntilMessage(ctx); err == nil {
t.Errorf("Expected error but received %+v", msg)
} else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
t.Errorf("Expected close error but received %+v", err)
}
client3 := NewTestClient(t, server1, hub1)
defer client3.CloseWithBye()
if err := client3.SendHelloResume(hello.Hello.ResumeId); err != nil {
t.Fatal(err)
}
hello3, err := client3.RunUntilHello(ctx)
if err != nil {
t.Error(err)
} else {
if hello3.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
}
if hello3.Hello.SessionId != hello.Hello.SessionId {
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
}
if hello3.Hello.ResumeId != hello.Hello.ResumeId {
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
}
}
// The second client got disconnected with a reason in a "Bye" message.
if msg, err := client2.RunUntilMessage(ctx); err != nil {
t.Error(err)
} else {
if msg.Type != "bye" || msg.Bye == nil {
t.Errorf("Expected bye message, got %+v", msg)
} else if msg.Bye.Reason != "session_resumed" {
t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason)
}
}
if msg, err := client2.RunUntilMessage(ctx); err == nil {
t.Errorf("Expected error but received %+v", msg)
} else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
t.Errorf("Expected close error but received %+v", err)
}
})
})
}
func TestClientHelloResumeProxy_Disconnect(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) {
client1 := NewTestClient(t, server1, hub1)
defer client1.CloseWithBye()
if err := client1.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello, err := client1.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
} else {
if hello.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
if hello.Hello.ResumeId == "" {
t.Errorf("Expected resume id, got %+v", hello.Hello)
}
}
client1.Close()
if err := client1.WaitForClientRemoved(ctx); err != nil {
t.Error(err)
}
client2 := NewTestClient(t, server2, hub2)
defer client2.CloseWithBye()
if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil {
t.Fatal(err)
}
hello2, err := client2.RunUntilHello(ctx)
if err != nil {
t.Error(err)
} else {
if hello2.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
}
if hello2.Hello.SessionId != hello.Hello.SessionId {
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
}
if hello2.Hello.ResumeId != hello.Hello.ResumeId {
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
}
}
// Simulate unclean shutdown of second instance.
hub2.rpcServer.conn.Stop()
if err := client2.WaitForClientRemoved(ctx); err != nil {
t.Error(err)
}
})
})
}
func TestClientHelloClient(t *testing.T) {
hub, _, _, server := CreateHubForTest(t)

View file

@ -53,18 +53,18 @@ func (c *ProxyClient) SetSession(session *ProxySession) {
c.session.Store(session)
}
func (c *ProxyClient) OnClosed(client *signaling.Client) {
func (c *ProxyClient) OnClosed(client signaling.HandlerClient) {
if session := c.GetSession(); session != nil {
session.MarkUsed()
}
c.proxy.clientClosed(&c.Client)
}
func (c *ProxyClient) OnMessageReceived(client *signaling.Client, data []byte) {
func (c *ProxyClient) OnMessageReceived(client signaling.HandlerClient, data []byte) {
c.proxy.processMessage(c, data)
}
func (c *ProxyClient) OnRTTReceived(client *signaling.Client, rtt time.Duration) {
func (c *ProxyClient) OnRTTReceived(client signaling.HandlerClient, rtt time.Duration) {
if session := c.GetSession(); session != nil {
session.MarkUsed()
}

152
remotesession.go Normal file
View file

@ -0,0 +1,152 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2024 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"sync/atomic"
"time"
)
type RemoteSession struct {
hub *Hub
client *Client
remoteClient *GrpcClient
sessionId string
proxy atomic.Pointer[SessionProxy]
}
func NewRemoteSession(hub *Hub, client *Client, remoteClient *GrpcClient, sessionId string) (*RemoteSession, error) {
remoteSession := &RemoteSession{
hub: hub,
client: client,
remoteClient: remoteClient,
sessionId: sessionId,
}
client.SetSessionId(sessionId)
client.SetHandler(remoteSession)
proxy, err := remoteClient.ProxySession(context.Background(), sessionId, remoteSession)
if err != nil {
return nil, err
}
remoteSession.proxy.Store(proxy)
return remoteSession, nil
}
func (s *RemoteSession) Country() string {
return s.client.Country()
}
func (s *RemoteSession) RemoteAddr() string {
return s.client.RemoteAddr()
}
func (s *RemoteSession) UserAgent() string {
return s.client.UserAgent()
}
func (s *RemoteSession) IsConnected() bool {
return true
}
func (s *RemoteSession) Start(message *ClientMessage) error {
return s.sendMessage(message)
}
func (s *RemoteSession) OnProxyMessage(msg *ServerSessionMessage) error {
var message *ServerMessage
if err := json.Unmarshal(msg.Message, &message); err != nil {
return err
}
if !s.client.SendMessage(message) {
return fmt.Errorf("could not send message to client")
}
return nil
}
func (s *RemoteSession) OnProxyClose(err error) {
if err != nil {
log.Printf("Proxy connection for session %s to %s was closed with error: %s", s.sessionId, s.remoteClient.Target(), err)
}
s.Close()
}
func (s *RemoteSession) SendMessage(message WritableClientMessage) bool {
return s.sendMessage(message) == nil
}
func (s *RemoteSession) sendProxyMessage(message []byte) error {
proxy := s.proxy.Load()
if proxy == nil {
return errors.New("proxy already closed")
}
msg := &ClientSessionMessage{
Message: message,
}
return proxy.Send(msg)
}
func (s *RemoteSession) sendMessage(message interface{}) error {
data, err := json.Marshal(message)
if err != nil {
return err
}
return s.sendProxyMessage(data)
}
func (s *RemoteSession) Close() {
if proxy := s.proxy.Swap(nil); proxy != nil {
proxy.Close()
}
s.hub.unregisterRemoteSession(s)
s.client.Close()
}
func (s *RemoteSession) OnLookupCountry(client HandlerClient) string {
return s.hub.OnLookupCountry(client)
}
func (s *RemoteSession) OnClosed(client HandlerClient) {
s.Close()
}
func (s *RemoteSession) OnMessageReceived(client HandlerClient, message []byte) {
if err := s.sendProxyMessage(message); err != nil {
log.Printf("Error sending %s to the proxy for session %s: %s", string(message), s.sessionId, err)
s.Close()
}
}
func (s *RemoteSession) OnRTTReceived(client HandlerClient, rtt time.Duration) {
}

View file

@ -86,6 +86,14 @@ func (s *DummySession) HasPermission(permission Permission) bool {
return false
}
func (s *DummySession) SendError(e *Error) bool {
return false
}
func (s *DummySession) SendMessage(message *ServerMessage) bool {
return false
}
func checkSession(t *testing.T, sessions RoomSessions, sessionId string, roomSessionId string) Session {
session := &DummySession{
publicId: sessionId,

View file

@ -72,4 +72,7 @@ type Session interface {
Close()
HasPermission(permission Permission) bool
SendError(e *Error) bool
SendMessage(message *ServerMessage) bool
}

View file

@ -311,12 +311,14 @@ func (c *TestClient) WaitForClientRemoved(ctx context.Context) error {
for {
found := false
for _, client := range c.hub.clients {
client.mu.Lock()
conn := client.conn
client.mu.Unlock()
if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() {
found = true
break
if cc, ok := client.(*Client); ok {
cc.mu.Lock()
conn := cc.conn
cc.mu.Unlock()
if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() {
found = true
break
}
}
}
if !found {
@ -493,7 +495,7 @@ func (c *TestClient) SendHelloParams(url string, version string, clientType stri
Hello: &HelloClientMessage{
Version: version,
Features: features,
Auth: HelloClientMessageAuth{
Auth: &HelloClientMessageAuth{
Type: clientType,
Url: url,
Params: (*json.RawMessage)(&data),

View file

@ -51,7 +51,7 @@ type VirtualSession struct {
options *AddSessionOptions
}
func GetVirtualSessionId(session *ClientSession, sessionId string) string {
func GetVirtualSessionId(session Session, sessionId string) string {
return session.PublicId() + "|" + sessionId
}
@ -163,7 +163,7 @@ func (s *VirtualSession) Close() {
s.CloseWithFeedback(nil, nil)
}
func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *ClientMessage) {
func (s *VirtualSession) CloseWithFeedback(session Session, message *ClientMessage) {
room := s.GetRoom()
s.session.RemoveVirtualSession(s)
removed := s.session.hub.removeSession(s)
@ -173,7 +173,7 @@ func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *Clie
s.session.events.UnregisterSessionListener(s.PublicId(), s.session.Backend(), s)
}
func (s *VirtualSession) notifyBackendRemoved(room *Room, session *ClientSession, message *ClientMessage) {
func (s *VirtualSession) notifyBackendRemoved(room *Room, session Session, message *ClientMessage) {
ctx, cancel := context.WithTimeout(context.Background(), s.hub.backendTimeout)
defer cancel()
@ -321,3 +321,11 @@ func (s *VirtualSession) ProcessAsyncSessionMessage(message *AsyncMessage) {
}
}
}
func (s *VirtualSession) SendError(e *Error) bool {
return s.session.SendError(e)
}
func (s *VirtualSession) SendMessage(message *ServerMessage) bool {
return s.session.SendMessage(message)
}