mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-06-10 09:52:12 +02:00
Merge pull request #715 from strukturag/resume-remote
Support resuming remote sessions
This commit is contained in:
commit
d368a060fa
|
@ -387,7 +387,7 @@ type HelloClientMessage struct {
|
|||
Features []string `json:"features,omitempty"`
|
||||
|
||||
// The authentication credentials.
|
||||
Auth HelloClientMessageAuth `json:"auth"`
|
||||
Auth *HelloClientMessageAuth `json:"auth,omitempty"`
|
||||
}
|
||||
|
||||
func (m *HelloClientMessage) CheckValid() error {
|
||||
|
@ -395,7 +395,7 @@ func (m *HelloClientMessage) CheckValid() error {
|
|||
return InvalidHelloVersion
|
||||
}
|
||||
if m.ResumeId == "" {
|
||||
if m.Auth.Params == nil || len(*m.Auth.Params) == 0 {
|
||||
if m.Auth == nil || m.Auth.Params == nil || len(*m.Auth.Params) == 0 {
|
||||
return fmt.Errorf("params missing")
|
||||
}
|
||||
if m.Auth.Type == "" {
|
||||
|
|
|
@ -95,14 +95,14 @@ func TestHelloClientMessage(t *testing.T) {
|
|||
// Hello version 1
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV1,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Params: &json.RawMessage{'{', '}'},
|
||||
Url: "https://domain.invalid",
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV1,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Type: "client",
|
||||
Params: &json.RawMessage{'{', '}'},
|
||||
Url: "https://domain.invalid",
|
||||
|
@ -110,7 +110,7 @@ func TestHelloClientMessage(t *testing.T) {
|
|||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV1,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Type: "internal",
|
||||
Params: (*json.RawMessage)(&internalAuthParams),
|
||||
},
|
||||
|
@ -122,14 +122,14 @@ func TestHelloClientMessage(t *testing.T) {
|
|||
// Hello version 2
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV2,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Params: (*json.RawMessage)(&tokenAuthParams),
|
||||
Url: "https://domain.invalid",
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV2,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Type: "client",
|
||||
Params: (*json.RawMessage)(&tokenAuthParams),
|
||||
Url: "https://domain.invalid",
|
||||
|
@ -147,40 +147,40 @@ func TestHelloClientMessage(t *testing.T) {
|
|||
&HelloClientMessage{Version: HelloVersionV1},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV1,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Params: &json.RawMessage{'{', '}'},
|
||||
Type: "invalid-type",
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV1,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Url: "https://domain.invalid",
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV1,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Params: &json.RawMessage{'{', '}'},
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV1,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Params: &json.RawMessage{'{', '}'},
|
||||
Url: "invalid-url",
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV1,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Type: "internal",
|
||||
Params: &json.RawMessage{'{', '}'},
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV1,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Type: "internal",
|
||||
Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON.
|
||||
},
|
||||
|
@ -188,33 +188,33 @@ func TestHelloClientMessage(t *testing.T) {
|
|||
// Hello version 2
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV2,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Url: "https://domain.invalid",
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV2,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Params: (*json.RawMessage)(&tokenAuthParams),
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV2,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Params: (*json.RawMessage)(&tokenAuthParams),
|
||||
Url: "invalid-url",
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV2,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Params: (*json.RawMessage)(&internalAuthParams),
|
||||
Url: "https://domain.invalid",
|
||||
},
|
||||
},
|
||||
&HelloClientMessage{
|
||||
Version: HelloVersionV2,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Params: &json.RawMessage{'x', 'y', 'z'}, // Invalid JSON.
|
||||
Url: "https://domain.invalid",
|
||||
},
|
||||
|
|
107
client.go
107
client.go
|
@ -92,14 +92,32 @@ type WritableClientMessage interface {
|
|||
CloseAfterSend(session Session) bool
|
||||
}
|
||||
|
||||
type HandlerClient interface {
|
||||
RemoteAddr() string
|
||||
Country() string
|
||||
UserAgent() string
|
||||
IsConnected() bool
|
||||
IsAuthenticated() bool
|
||||
|
||||
GetSession() Session
|
||||
SetSession(session Session)
|
||||
|
||||
SendError(e *Error) bool
|
||||
SendByeResponse(message *ClientMessage) bool
|
||||
SendByeResponseWithReason(message *ClientMessage, reason string) bool
|
||||
SendMessage(message WritableClientMessage) bool
|
||||
|
||||
Close()
|
||||
}
|
||||
|
||||
type ClientHandler interface {
|
||||
OnClosed(*Client)
|
||||
OnMessageReceived(*Client, []byte)
|
||||
OnRTTReceived(*Client, time.Duration)
|
||||
OnClosed(HandlerClient)
|
||||
OnMessageReceived(HandlerClient, []byte)
|
||||
OnRTTReceived(HandlerClient, time.Duration)
|
||||
}
|
||||
|
||||
type ClientGeoIpHandler interface {
|
||||
OnLookupCountry(*Client) string
|
||||
OnLookupCountry(HandlerClient) string
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
|
@ -111,7 +129,8 @@ type Client struct {
|
|||
country *string
|
||||
logRTT bool
|
||||
|
||||
session atomic.Pointer[ClientSession]
|
||||
session atomic.Pointer[Session]
|
||||
sessionId atomic.Pointer[string]
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
|
@ -142,12 +161,16 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler
|
|||
func (c *Client) SetConn(conn *websocket.Conn, remoteAddress string, handler ClientHandler) {
|
||||
c.conn = conn
|
||||
c.addr = remoteAddress
|
||||
c.handler = handler
|
||||
c.SetHandler(handler)
|
||||
c.closer = NewCloser()
|
||||
c.messageChan = make(chan *bytes.Buffer, 16)
|
||||
c.messagesDone = make(chan struct{})
|
||||
}
|
||||
|
||||
func (c *Client) SetHandler(handler ClientHandler) {
|
||||
c.handler = handler
|
||||
}
|
||||
|
||||
func (c *Client) IsConnected() bool {
|
||||
return c.closed.Load() == 0
|
||||
}
|
||||
|
@ -156,12 +179,39 @@ func (c *Client) IsAuthenticated() bool {
|
|||
return c.GetSession() != nil
|
||||
}
|
||||
|
||||
func (c *Client) GetSession() *ClientSession {
|
||||
return c.session.Load()
|
||||
func (c *Client) GetSession() Session {
|
||||
session := c.session.Load()
|
||||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return *session
|
||||
}
|
||||
|
||||
func (c *Client) SetSession(session *ClientSession) {
|
||||
c.session.Store(session)
|
||||
func (c *Client) SetSession(session Session) {
|
||||
if session == nil {
|
||||
c.session.Store(nil)
|
||||
} else {
|
||||
c.session.Store(&session)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) SetSessionId(sessionId string) {
|
||||
c.sessionId.Store(&sessionId)
|
||||
}
|
||||
|
||||
func (c *Client) GetSessionId() string {
|
||||
sessionId := c.sessionId.Load()
|
||||
if sessionId == nil {
|
||||
session := c.GetSession()
|
||||
if session == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return session.PublicId()
|
||||
}
|
||||
|
||||
return *sessionId
|
||||
}
|
||||
|
||||
func (c *Client) RemoteAddr() string {
|
||||
|
@ -234,12 +284,14 @@ func (c *Client) SendByeResponse(message *ClientMessage) bool {
|
|||
func (c *Client) SendByeResponseWithReason(message *ClientMessage, reason string) bool {
|
||||
response := &ServerMessage{
|
||||
Type: "bye",
|
||||
Bye: &ByeServerMessage{},
|
||||
}
|
||||
if message != nil {
|
||||
response.Id = message.Id
|
||||
}
|
||||
if reason != "" {
|
||||
if response.Bye == nil {
|
||||
response.Bye = &ByeServerMessage{}
|
||||
}
|
||||
response.Bye.Reason = reason
|
||||
}
|
||||
return c.SendMessage(response)
|
||||
|
@ -277,8 +329,8 @@ func (c *Client) ReadPump() {
|
|||
rtt := now.Sub(time.Unix(0, ts))
|
||||
if c.logRTT {
|
||||
rtt_ms := rtt.Nanoseconds() / time.Millisecond.Nanoseconds()
|
||||
if session := c.GetSession(); session != nil {
|
||||
log.Printf("Client %s has RTT of %d ms (%s)", session.PublicId(), rtt_ms, rtt)
|
||||
if sessionId := c.GetSessionId(); sessionId != "" {
|
||||
log.Printf("Client %s has RTT of %d ms (%s)", sessionId, rtt_ms, rtt)
|
||||
} else {
|
||||
log.Printf("Client from %s has RTT of %d ms (%s)", addr, rtt_ms, rtt)
|
||||
}
|
||||
|
@ -296,8 +348,8 @@ func (c *Client) ReadPump() {
|
|||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNoStatusReceived) {
|
||||
if session := c.GetSession(); session != nil {
|
||||
log.Printf("Error reading from client %s: %v", session.PublicId(), err)
|
||||
if sessionId := c.GetSessionId(); sessionId != "" {
|
||||
log.Printf("Error reading from client %s: %v", sessionId, err)
|
||||
} else {
|
||||
log.Printf("Error reading from %s: %v", addr, err)
|
||||
}
|
||||
|
@ -306,8 +358,8 @@ func (c *Client) ReadPump() {
|
|||
}
|
||||
|
||||
if messageType != websocket.TextMessage {
|
||||
if session := c.GetSession(); session != nil {
|
||||
log.Printf("Unsupported message type %v from client %s", messageType, session.PublicId())
|
||||
if sessionId := c.GetSessionId(); sessionId != "" {
|
||||
log.Printf("Unsupported message type %v from client %s", messageType, sessionId)
|
||||
} else {
|
||||
log.Printf("Unsupported message type %v from %s", messageType, addr)
|
||||
}
|
||||
|
@ -319,8 +371,8 @@ func (c *Client) ReadPump() {
|
|||
decodeBuffer.Reset()
|
||||
if _, err := decodeBuffer.ReadFrom(reader); err != nil {
|
||||
bufferPool.Put(decodeBuffer)
|
||||
if session := c.GetSession(); session != nil {
|
||||
log.Printf("Error reading message from client %s: %v", session.PublicId(), err)
|
||||
if sessionId := c.GetSessionId(); sessionId != "" {
|
||||
log.Printf("Error reading message from client %s: %v", sessionId, err)
|
||||
} else {
|
||||
log.Printf("Error reading message from %s: %v", addr, err)
|
||||
}
|
||||
|
@ -373,8 +425,8 @@ func (c *Client) writeInternal(message json.Marshaler) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
if session := c.GetSession(); session != nil {
|
||||
log.Printf("Could not send message %+v to client %s: %v", message, session.PublicId(), err)
|
||||
if sessionId := c.GetSessionId(); sessionId != "" {
|
||||
log.Printf("Could not send message %+v to client %s: %v", message, sessionId, err)
|
||||
} else {
|
||||
log.Printf("Could not send message %+v to %s: %v", message, c.RemoteAddr(), err)
|
||||
}
|
||||
|
@ -386,8 +438,8 @@ func (c *Client) writeInternal(message json.Marshaler) bool {
|
|||
close:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint
|
||||
if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil {
|
||||
if session := c.GetSession(); session != nil {
|
||||
log.Printf("Could not send close message to client %s: %v", session.PublicId(), err)
|
||||
if sessionId := c.GetSessionId(); sessionId != "" {
|
||||
log.Printf("Could not send close message to client %s: %v", sessionId, err)
|
||||
} else {
|
||||
log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err)
|
||||
}
|
||||
|
@ -413,8 +465,8 @@ func (c *Client) writeError(e error) bool { // nolint
|
|||
closeData := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, e.Error())
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint
|
||||
if err := c.conn.WriteMessage(websocket.CloseMessage, closeData); err != nil {
|
||||
if session := c.GetSession(); session != nil {
|
||||
log.Printf("Could not send close message to client %s: %v", session.PublicId(), err)
|
||||
if sessionId := c.GetSessionId(); sessionId != "" {
|
||||
log.Printf("Could not send close message to client %s: %v", sessionId, err)
|
||||
} else {
|
||||
log.Printf("Could not send close message to %s: %v", c.RemoteAddr(), err)
|
||||
}
|
||||
|
@ -445,7 +497,6 @@ func (c *Client) writeMessageLocked(message WritableClientMessage) bool {
|
|||
go session.Close()
|
||||
}
|
||||
go c.Close()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
|
@ -462,8 +513,8 @@ func (c *Client) sendPing() bool {
|
|||
msg := strconv.FormatInt(now, 10)
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // nolint
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, []byte(msg)); err != nil {
|
||||
if session := c.GetSession(); session != nil {
|
||||
log.Printf("Could not send ping to client %s: %v", session.PublicId(), err)
|
||||
if sessionId := c.GetSessionId(); sessionId != "" {
|
||||
log.Printf("Could not send ping to client %s: %v", sessionId, err)
|
||||
} else {
|
||||
log.Printf("Could not send ping to %s: %v", c.RemoteAddr(), err)
|
||||
}
|
||||
|
|
|
@ -601,7 +601,7 @@ func main() {
|
|||
Type: "hello",
|
||||
Hello: &signaling.HelloClientMessage{
|
||||
Version: signaling.HelloVersionV1,
|
||||
Auth: signaling.HelloClientMessageAuth{
|
||||
Auth: &signaling.HelloClientMessageAuth{
|
||||
Url: backendUrl + "/auth",
|
||||
Params: &json.RawMessage{'{', '}'},
|
||||
},
|
||||
|
|
|
@ -67,7 +67,7 @@ type ClientSession struct {
|
|||
|
||||
mu sync.Mutex
|
||||
|
||||
client *Client
|
||||
client HandlerClient
|
||||
room atomic.Pointer[Room]
|
||||
roomJoinTime atomic.Int64
|
||||
roomSessionId string
|
||||
|
@ -500,14 +500,14 @@ func (s *ClientSession) doUnsubscribeRoomEvents(notify bool) {
|
|||
s.roomSessionId = ""
|
||||
}
|
||||
|
||||
func (s *ClientSession) ClearClient(client *Client) {
|
||||
func (s *ClientSession) ClearClient(client HandlerClient) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.clearClientLocked(client)
|
||||
}
|
||||
|
||||
func (s *ClientSession) clearClientLocked(client *Client) {
|
||||
func (s *ClientSession) clearClientLocked(client HandlerClient) {
|
||||
if s.client == nil {
|
||||
return
|
||||
} else if client != nil && s.client != client {
|
||||
|
@ -520,18 +520,18 @@ func (s *ClientSession) clearClientLocked(client *Client) {
|
|||
prevClient.SetSession(nil)
|
||||
}
|
||||
|
||||
func (s *ClientSession) GetClient() *Client {
|
||||
func (s *ClientSession) GetClient() HandlerClient {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return s.getClientUnlocked()
|
||||
}
|
||||
|
||||
func (s *ClientSession) getClientUnlocked() *Client {
|
||||
func (s *ClientSession) getClientUnlocked() HandlerClient {
|
||||
return s.client
|
||||
}
|
||||
|
||||
func (s *ClientSession) SetClient(client *Client) *Client {
|
||||
func (s *ClientSession) SetClient(client HandlerClient) HandlerClient {
|
||||
if client == nil {
|
||||
panic("Use ClearClient to set the client to nil")
|
||||
}
|
||||
|
@ -1341,7 +1341,7 @@ func (s *ClientSession) filterAsyncMessage(msg *AsyncMessage) *ServerMessage {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *ClientSession) NotifySessionResumed(client *Client) {
|
||||
func (s *ClientSession) NotifySessionResumed(client HandlerClient) {
|
||||
s.mu.Lock()
|
||||
if len(s.pendingClientMessages) == 0 {
|
||||
s.mu.Unlock()
|
||||
|
|
104
grpc_client.go
104
grpc_client.go
|
@ -26,6 +26,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
|
@ -39,6 +40,7 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/resolver"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
@ -51,6 +53,8 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrNoSuchResumeId = fmt.Errorf("unknown resume id")
|
||||
|
||||
customResolverPrefix atomic.Uint64
|
||||
)
|
||||
|
||||
|
@ -185,6 +189,26 @@ func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) {
|
|||
return response.GetServerId(), nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) LookupResumeId(ctx context.Context, resumeId string) (*LookupResumeIdReply, error) {
|
||||
statsGrpcClientCalls.WithLabelValues("LookupResumeId").Inc()
|
||||
// TODO: Remove debug logging
|
||||
log.Printf("Lookup resume id %s on %s", resumeId, c.Target())
|
||||
response, err := c.impl.LookupResumeId(ctx, &LookupResumeIdRequest{
|
||||
ResumeId: resumeId,
|
||||
}, grpc.WaitForReady(true))
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
return nil, ErrNoSuchResumeId
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sessionId := response.GetSessionId(); sessionId == "" {
|
||||
return nil, ErrNoSuchResumeId
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) LookupSessionId(ctx context.Context, roomSessionId string, disconnectReason string) (string, error) {
|
||||
statsGrpcClientCalls.WithLabelValues("LookupSessionId").Inc()
|
||||
// TODO: Remove debug logging
|
||||
|
@ -258,6 +282,86 @@ func (c *GrpcClient) GetSessionCount(ctx context.Context, u *url.URL) (uint32, e
|
|||
return response.GetCount(), nil
|
||||
}
|
||||
|
||||
type ProxySessionReceiver interface {
|
||||
RemoteAddr() string
|
||||
Country() string
|
||||
UserAgent() string
|
||||
|
||||
OnProxyMessage(message *ServerSessionMessage) error
|
||||
OnProxyClose(err error)
|
||||
}
|
||||
|
||||
type SessionProxy struct {
|
||||
sessionId string
|
||||
receiver ProxySessionReceiver
|
||||
|
||||
sendMu sync.Mutex
|
||||
client RpcSessions_ProxySessionClient
|
||||
}
|
||||
|
||||
func (p *SessionProxy) recvPump() {
|
||||
var closeError error
|
||||
defer func() {
|
||||
p.receiver.OnProxyClose(closeError)
|
||||
if err := p.Close(); err != nil {
|
||||
log.Printf("Error closing proxy for session %s: %s", p.sessionId, err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
msg, err := p.client.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
log.Printf("Error receiving message from proxy for session %s: %s", p.sessionId, err)
|
||||
closeError = err
|
||||
break
|
||||
}
|
||||
|
||||
if err := p.receiver.OnProxyMessage(msg); err != nil {
|
||||
log.Printf("Error processing message %+v from proxy for session %s: %s", msg, p.sessionId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SessionProxy) Send(message *ClientSessionMessage) error {
|
||||
p.sendMu.Lock()
|
||||
defer p.sendMu.Unlock()
|
||||
return p.client.Send(message)
|
||||
}
|
||||
|
||||
func (p *SessionProxy) Close() error {
|
||||
p.sendMu.Lock()
|
||||
defer p.sendMu.Unlock()
|
||||
return p.client.CloseSend()
|
||||
}
|
||||
|
||||
func (c *GrpcClient) ProxySession(ctx context.Context, sessionId string, receiver ProxySessionReceiver) (*SessionProxy, error) {
|
||||
statsGrpcClientCalls.WithLabelValues("ProxySession").Inc()
|
||||
md := metadata.Pairs(
|
||||
"sessionId", sessionId,
|
||||
"remoteAddr", receiver.RemoteAddr(),
|
||||
"country", receiver.Country(),
|
||||
"userAgent", receiver.UserAgent(),
|
||||
)
|
||||
client, err := c.impl.ProxySession(metadata.NewOutgoingContext(ctx, md), grpc.WaitForReady(true))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxy := &SessionProxy{
|
||||
sessionId: sessionId,
|
||||
receiver: receiver,
|
||||
|
||||
client: client,
|
||||
}
|
||||
|
||||
go proxy.recvPump()
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
type grpcClientsList struct {
|
||||
clients []*GrpcClient
|
||||
entry *DnsMonitorEntry
|
||||
|
|
225
grpc_remote_client.go
Normal file
225
grpc_remote_client.go
Normal file
|
@ -0,0 +1,225 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2024 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
grpcRemoteClientMessageQueue = 16
|
||||
)
|
||||
|
||||
func getMD(md metadata.MD, key string) string {
|
||||
if values := md.Get(key); len(values) > 0 {
|
||||
return values[0]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// remoteGrpcClient is a remote client connecting from a GRPC proxy to a Hub.
|
||||
type remoteGrpcClient struct {
|
||||
hub *Hub
|
||||
client RpcSessions_ProxySessionServer
|
||||
|
||||
sessionId string
|
||||
remoteAddr string
|
||||
country string
|
||||
userAgent string
|
||||
|
||||
closeCtx context.Context
|
||||
closeFunc context.CancelCauseFunc
|
||||
|
||||
session atomic.Pointer[Session]
|
||||
messages chan WritableClientMessage
|
||||
}
|
||||
|
||||
func newRemoteGrpcClient(hub *Hub, request RpcSessions_ProxySessionServer) (*remoteGrpcClient, error) {
|
||||
md, found := metadata.FromIncomingContext(request.Context())
|
||||
if !found {
|
||||
return nil, errors.New("no metadata provided")
|
||||
}
|
||||
|
||||
closeCtx, closeFunc := context.WithCancelCause(context.Background())
|
||||
|
||||
result := &remoteGrpcClient{
|
||||
hub: hub,
|
||||
client: request,
|
||||
|
||||
sessionId: getMD(md, "sessionId"),
|
||||
remoteAddr: getMD(md, "remoteAddr"),
|
||||
country: getMD(md, "country"),
|
||||
userAgent: getMD(md, "userAgent"),
|
||||
|
||||
closeCtx: closeCtx,
|
||||
closeFunc: closeFunc,
|
||||
|
||||
messages: make(chan WritableClientMessage, grpcRemoteClientMessageQueue),
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) readPump() {
|
||||
var closeError error
|
||||
defer func() {
|
||||
c.closeFunc(closeError)
|
||||
c.hub.OnClosed(c)
|
||||
}()
|
||||
|
||||
for {
|
||||
msg, err := c.client.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
// Connection was closed locally.
|
||||
break
|
||||
}
|
||||
|
||||
if status.Code(err) != codes.Canceled {
|
||||
log.Printf("Error reading from remote client for session %s: %s", c.sessionId, err)
|
||||
closeError = err
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
c.hub.OnMessageReceived(c, msg.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) RemoteAddr() string {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) UserAgent() string {
|
||||
return c.userAgent
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) Country() string {
|
||||
return c.country
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) IsConnected() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) IsAuthenticated() bool {
|
||||
return c.GetSession() != nil
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) GetSession() Session {
|
||||
session := c.session.Load()
|
||||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return *session
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) SetSession(session Session) {
|
||||
if session == nil {
|
||||
c.session.Store(nil)
|
||||
} else {
|
||||
c.session.Store(&session)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) SendError(e *Error) bool {
|
||||
message := &ServerMessage{
|
||||
Type: "error",
|
||||
Error: e,
|
||||
}
|
||||
return c.SendMessage(message)
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) SendByeResponse(message *ClientMessage) bool {
|
||||
return c.SendByeResponseWithReason(message, "")
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) SendByeResponseWithReason(message *ClientMessage, reason string) bool {
|
||||
response := &ServerMessage{
|
||||
Type: "bye",
|
||||
}
|
||||
if message != nil {
|
||||
response.Id = message.Id
|
||||
}
|
||||
if reason != "" {
|
||||
if response.Bye == nil {
|
||||
response.Bye = &ByeServerMessage{}
|
||||
}
|
||||
response.Bye.Reason = reason
|
||||
}
|
||||
return c.SendMessage(response)
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) SendMessage(message WritableClientMessage) bool {
|
||||
if c.closeCtx.Err() != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
select {
|
||||
case c.messages <- message:
|
||||
return true
|
||||
default:
|
||||
log.Printf("Message queue for remote client of session %s is full, not sending %+v", c.sessionId, message)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) Close() {
|
||||
c.closeFunc(nil)
|
||||
}
|
||||
|
||||
func (c *remoteGrpcClient) run() error {
|
||||
go c.readPump()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.closeCtx.Done():
|
||||
if err := context.Cause(c.closeCtx); err != context.Canceled {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
case msg := <-c.messages:
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
log.Printf("Error marshalling %+v for remote client for session %s: %s", msg, c.sessionId, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.client.Send(&ServerSessionMessage{
|
||||
Message: data,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error sending %+v to remote client for session %s: %w", msg, c.sessionId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -113,6 +113,20 @@ func (s *GrpcServer) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *GrpcServer) LookupResumeId(ctx context.Context, request *LookupResumeIdRequest) (*LookupResumeIdReply, error) {
|
||||
statsGrpcServerCalls.WithLabelValues("LookupResumeId").Inc()
|
||||
// TODO: Remove debug logging
|
||||
log.Printf("Lookup session for resume id %s", request.ResumeId)
|
||||
session := s.hub.GetSessionByResumeId(request.ResumeId)
|
||||
if session == nil {
|
||||
return nil, status.Error(codes.NotFound, "no such room session id")
|
||||
}
|
||||
|
||||
return &LookupResumeIdReply{
|
||||
SessionId: session.PublicId(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) {
|
||||
statsGrpcServerCalls.WithLabelValues("LookupSessionId").Inc()
|
||||
// TODO: Remove debug logging
|
||||
|
@ -216,3 +230,16 @@ func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCou
|
|||
Count: uint32(backend.Len()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GrpcServer) ProxySession(request RpcSessions_ProxySessionServer) error {
|
||||
statsGrpcServerCalls.WithLabelValues("ProxySession").Inc()
|
||||
client, err := newRemoteGrpcClient(s.hub, request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sid := s.hub.registerClient(client)
|
||||
defer s.hub.unregisterClient(sid)
|
||||
|
||||
return client.run()
|
||||
}
|
||||
|
|
|
@ -26,8 +26,18 @@ option go_package = "github.com/strukturag/nextcloud-spreed-signaling;signaling"
|
|||
package signaling;
|
||||
|
||||
service RpcSessions {
|
||||
rpc LookupResumeId(LookupResumeIdRequest) returns (LookupResumeIdReply) {}
|
||||
rpc LookupSessionId(LookupSessionIdRequest) returns (LookupSessionIdReply) {}
|
||||
rpc IsSessionInCall(IsSessionInCallRequest) returns (IsSessionInCallReply) {}
|
||||
rpc ProxySession(stream ClientSessionMessage) returns (stream ServerSessionMessage) {}
|
||||
}
|
||||
|
||||
message LookupResumeIdRequest {
|
||||
string resumeId = 1;
|
||||
}
|
||||
|
||||
message LookupResumeIdReply {
|
||||
string sessionId = 1;
|
||||
}
|
||||
|
||||
message LookupSessionIdRequest {
|
||||
|
@ -49,3 +59,11 @@ message IsSessionInCallRequest {
|
|||
message IsSessionInCallReply {
|
||||
bool inCall = 1;
|
||||
}
|
||||
|
||||
message ClientSessionMessage {
|
||||
bytes message = 1;
|
||||
}
|
||||
|
||||
message ServerSessionMessage {
|
||||
bytes message = 1;
|
||||
}
|
||||
|
|
279
hub.go
279
hub.go
|
@ -135,7 +135,7 @@ type Hub struct {
|
|||
ru sync.RWMutex
|
||||
|
||||
sid atomic.Uint64
|
||||
clients map[uint64]*Client
|
||||
clients map[uint64]HandlerClient
|
||||
sessions map[uint64]Session
|
||||
rooms map[string]*Room
|
||||
|
||||
|
@ -153,8 +153,9 @@ type Hub struct {
|
|||
|
||||
expiredSessions map[Session]time.Time
|
||||
anonymousSessions map[*ClientSession]time.Time
|
||||
expectHelloClients map[*Client]time.Time
|
||||
expectHelloClients map[HandlerClient]time.Time
|
||||
dialoutSessions map[*ClientSession]bool
|
||||
remoteSessions map[*RemoteSession]bool
|
||||
|
||||
backendTimeout time.Duration
|
||||
backend *BackendClient
|
||||
|
@ -324,7 +325,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
|
|||
roomInCall: make(chan *BackendServerRoomRequest),
|
||||
roomParticipants: make(chan *BackendServerRoomRequest),
|
||||
|
||||
clients: make(map[uint64]*Client),
|
||||
clients: make(map[uint64]HandlerClient),
|
||||
sessions: make(map[uint64]Session),
|
||||
rooms: make(map[string]*Room),
|
||||
|
||||
|
@ -341,8 +342,9 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
|
|||
|
||||
expiredSessions: make(map[Session]time.Time),
|
||||
anonymousSessions: make(map[*ClientSession]time.Time),
|
||||
expectHelloClients: make(map[*Client]time.Time),
|
||||
expectHelloClients: make(map[HandlerClient]time.Time),
|
||||
dialoutSessions: make(map[*ClientSession]bool),
|
||||
remoteSessions: make(map[*RemoteSession]bool),
|
||||
|
||||
backendTimeout: backendTimeout,
|
||||
backend: backend,
|
||||
|
@ -584,6 +586,22 @@ func (h *Hub) GetSessionByPublicId(sessionId string) Session {
|
|||
return session
|
||||
}
|
||||
|
||||
func (h *Hub) GetSessionByResumeId(resumeId string) Session {
|
||||
data := h.decodeSessionId(resumeId, privateSessionName)
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
session := h.sessions[data.Sid]
|
||||
if session != nil && session.PrivateId() != resumeId {
|
||||
// Session was created on different server.
|
||||
return nil
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func (h *Hub) GetDialoutSession(roomId string, backend *Backend) *ClientSession {
|
||||
url := backend.Url()
|
||||
|
||||
|
@ -690,15 +708,13 @@ func (h *Hub) startWaitAnonymousSessionRoomLocked(session *ClientSession) {
|
|||
h.anonymousSessions[session] = now.Add(anonmyousJoinRoomTimeout)
|
||||
}
|
||||
|
||||
func (h *Hub) startExpectHello(client *Client) {
|
||||
func (h *Hub) startExpectHello(client HandlerClient) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if !client.IsConnected() {
|
||||
return
|
||||
}
|
||||
|
||||
client.mu.Lock()
|
||||
defer client.mu.Unlock()
|
||||
if client.IsAuthenticated() {
|
||||
return
|
||||
}
|
||||
|
@ -708,15 +724,39 @@ func (h *Hub) startExpectHello(client *Client) {
|
|||
h.expectHelloClients[client] = now.Add(initialHelloTimeout)
|
||||
}
|
||||
|
||||
func (h *Hub) processNewClient(client *Client) {
|
||||
func (h *Hub) processNewClient(client HandlerClient) {
|
||||
h.startExpectHello(client)
|
||||
h.sendWelcome(client)
|
||||
}
|
||||
|
||||
func (h *Hub) sendWelcome(client *Client) {
|
||||
func (h *Hub) sendWelcome(client HandlerClient) {
|
||||
client.SendMessage(h.getWelcomeMessage())
|
||||
}
|
||||
|
||||
func (h *Hub) registerClient(client HandlerClient) uint64 {
|
||||
sid := h.sid.Add(1)
|
||||
for sid == 0 {
|
||||
sid = h.sid.Add(1)
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.clients[sid] = client
|
||||
return sid
|
||||
}
|
||||
|
||||
func (h *Hub) unregisterClient(sid uint64) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
delete(h.clients, sid)
|
||||
}
|
||||
|
||||
func (h *Hub) unregisterRemoteSession(session *RemoteSession) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
delete(h.remoteSessions, session)
|
||||
}
|
||||
|
||||
func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData {
|
||||
sid := h.sid.Add(1)
|
||||
for sid == 0 {
|
||||
|
@ -730,17 +770,24 @@ func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData {
|
|||
return sessionIdData
|
||||
}
|
||||
|
||||
func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *Backend, auth *BackendClientResponse) {
|
||||
if !client.IsConnected() {
|
||||
func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend *Backend, auth *BackendClientResponse) {
|
||||
if !c.IsConnected() {
|
||||
// Client disconnected while waiting for "hello" response.
|
||||
return
|
||||
}
|
||||
|
||||
if auth.Type == "error" {
|
||||
client.SendMessage(message.NewErrorServerMessage(auth.Error))
|
||||
c.SendMessage(message.NewErrorServerMessage(auth.Error))
|
||||
return
|
||||
} else if auth.Type != "auth" {
|
||||
client.SendMessage(message.NewErrorServerMessage(UserAuthFailed))
|
||||
c.SendMessage(message.NewErrorServerMessage(UserAuthFailed))
|
||||
return
|
||||
}
|
||||
|
||||
client, ok := c.(*Client)
|
||||
if !ok {
|
||||
log.Printf("Can't register non-client %T", c)
|
||||
client.SendMessage(message.NewWrappedErrorServerMessage(errors.New("can't register non-client")))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -844,7 +891,7 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
|
|||
h.sendHelloResponse(session, message)
|
||||
}
|
||||
|
||||
func (h *Hub) processUnregister(client *Client) *ClientSession {
|
||||
func (h *Hub) processUnregister(client HandlerClient) Session {
|
||||
session := client.GetSession()
|
||||
|
||||
h.mu.Lock()
|
||||
|
@ -857,14 +904,18 @@ func (h *Hub) processUnregister(client *Client) *ClientSession {
|
|||
h.mu.Unlock()
|
||||
if session != nil {
|
||||
log.Printf("Unregister %s (private=%s)", session.PublicId(), session.PrivateId())
|
||||
session.ClearClient(client)
|
||||
if c, ok := client.(*Client); ok {
|
||||
if cs, ok := session.(*ClientSession); ok {
|
||||
cs.ClearClient(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client.Close()
|
||||
return session
|
||||
}
|
||||
|
||||
func (h *Hub) processMessage(client *Client, data []byte) {
|
||||
func (h *Hub) processMessage(client HandlerClient, data []byte) {
|
||||
var message ClientMessage
|
||||
if err := message.UnmarshalJSON(data); err != nil {
|
||||
if session := client.GetSession(); session != nil {
|
||||
|
@ -944,12 +995,97 @@ func (h *Hub) sendHelloResponse(session *ClientSession, message *ClientMessage)
|
|||
return session.SendMessage(response)
|
||||
}
|
||||
|
||||
func (h *Hub) processHello(client *Client, message *ClientMessage) {
|
||||
type remoteClientInfo struct {
|
||||
client *GrpcClient
|
||||
response *LookupResumeIdReply
|
||||
}
|
||||
|
||||
func (h *Hub) tryProxyResume(c HandlerClient, resumeId string, message *ClientMessage) bool {
|
||||
client, ok := c.(*Client)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
var clients []*GrpcClient
|
||||
if h.rpcClients != nil {
|
||||
clients = h.rpcClients.GetClients()
|
||||
}
|
||||
if len(clients) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer rpcCancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ctx, cancel := context.WithCancel(rpcCtx)
|
||||
defer cancel()
|
||||
|
||||
var remoteClient atomic.Pointer[remoteClientInfo]
|
||||
for _, c := range clients {
|
||||
wg.Add(1)
|
||||
go func(client *GrpcClient) {
|
||||
defer wg.Done()
|
||||
|
||||
if client.IsSelf() {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := client.LookupResumeId(ctx, resumeId)
|
||||
if err != nil {
|
||||
log.Printf("Could not lookup resume id %s on %s: %s", resumeId, client.Target(), err)
|
||||
return
|
||||
}
|
||||
|
||||
cancel()
|
||||
remoteClient.CompareAndSwap(nil, &remoteClientInfo{
|
||||
client: client,
|
||||
response: response,
|
||||
})
|
||||
}(c)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if !client.IsConnected() {
|
||||
// Client disconnected while checking message.
|
||||
return false
|
||||
}
|
||||
|
||||
info := remoteClient.Load()
|
||||
if info == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
rs, err := NewRemoteSession(h, client, info.client, info.response.SessionId)
|
||||
if err != nil {
|
||||
log.Printf("Could not create remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := rs.Start(message); err != nil {
|
||||
rs.Close()
|
||||
log.Printf("Could not start remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Printf("Proxy session %s to %s", info.response.SessionId, info.client.Target())
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.remoteSessions[rs] = true
|
||||
delete(h.expectHelloClients, client)
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *Hub) processHello(client HandlerClient, message *ClientMessage) {
|
||||
resumeId := message.Hello.ResumeId
|
||||
if resumeId != "" {
|
||||
data := h.decodeSessionId(resumeId, privateSessionName)
|
||||
if data == nil {
|
||||
statsHubSessionResumeFailed.Inc()
|
||||
if h.tryProxyResume(client, resumeId, message) {
|
||||
return
|
||||
}
|
||||
|
||||
client.SendMessage(message.NewErrorServerMessage(NoSuchSession))
|
||||
return
|
||||
}
|
||||
|
@ -959,6 +1095,10 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) {
|
|||
if !found || resumeId != session.PrivateId() {
|
||||
h.mu.Unlock()
|
||||
statsHubSessionResumeFailed.Inc()
|
||||
if h.tryProxyResume(client, resumeId, message) {
|
||||
return
|
||||
}
|
||||
|
||||
client.SendMessage(message.NewErrorServerMessage(NoSuchSession))
|
||||
return
|
||||
}
|
||||
|
@ -1013,7 +1153,7 @@ func (h *Hub) processHello(client *Client, message *ClientMessage) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
|
||||
func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
|
||||
url := message.Hello.Auth.parsedUrl
|
||||
backend := h.backend.GetBackend(url)
|
||||
if backend == nil {
|
||||
|
@ -1035,7 +1175,7 @@ func (h *Hub) processHelloV1(client *Client, message *ClientMessage) (*Backend,
|
|||
return backend, &auth, nil
|
||||
}
|
||||
|
||||
func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
|
||||
func (h *Hub) processHelloV2(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
|
||||
url := message.Hello.Auth.parsedUrl
|
||||
backend := h.backend.GetBackend(url)
|
||||
if backend == nil {
|
||||
|
@ -1141,11 +1281,11 @@ func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend,
|
|||
return backend, auth, nil
|
||||
}
|
||||
|
||||
func (h *Hub) processHelloClient(client *Client, message *ClientMessage) {
|
||||
func (h *Hub) processHelloClient(client HandlerClient, message *ClientMessage) {
|
||||
// Make sure the client must send another "hello" in case of errors.
|
||||
defer h.startExpectHello(client)
|
||||
|
||||
var authFunc func(*Client, *ClientMessage) (*Backend, *BackendClientResponse, error)
|
||||
var authFunc func(HandlerClient, *ClientMessage) (*Backend, *BackendClientResponse, error)
|
||||
switch message.Hello.Version {
|
||||
case HelloVersionV1:
|
||||
// Auth information contains a ticket that must be validated against the
|
||||
|
@ -1172,7 +1312,7 @@ func (h *Hub) processHelloClient(client *Client, message *ClientMessage) {
|
|||
h.processRegister(client, message, backend, auth)
|
||||
}
|
||||
|
||||
func (h *Hub) processHelloInternal(client *Client, message *ClientMessage) {
|
||||
func (h *Hub) processHelloInternal(client HandlerClient, message *ClientMessage) {
|
||||
defer h.startExpectHello(client)
|
||||
if len(h.internalClientsSecret) == 0 {
|
||||
client.SendMessage(message.NewErrorServerMessage(InvalidClientType))
|
||||
|
@ -1261,8 +1401,12 @@ func (h *Hub) sendRoom(session *ClientSession, message *ClientMessage, room *Roo
|
|||
return session.SendMessage(response)
|
||||
}
|
||||
|
||||
func (h *Hub) processRoom(client *Client, message *ClientMessage) {
|
||||
session := client.GetSession()
|
||||
func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
|
||||
session, ok := client.GetSession().(*ClientSession)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
roomId := message.Room.RoomId
|
||||
if roomId == "" {
|
||||
if session == nil {
|
||||
|
@ -1281,29 +1425,34 @@ func (h *Hub) processRoom(client *Client, message *ClientMessage) {
|
|||
return
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) {
|
||||
// Session already is in that room, no action needed.
|
||||
roomSessionId := message.Room.SessionId
|
||||
if roomSessionId == "" {
|
||||
// TODO(jojo): Better make the session id required in the request.
|
||||
log.Printf("User did not send a room session id, assuming session %s", session.PublicId())
|
||||
roomSessionId = session.PublicId()
|
||||
}
|
||||
if session == nil {
|
||||
session.SendMessage(message.NewErrorServerMessage(
|
||||
NewError("not_authenticated", "Need to authenticate before joining rooms."),
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
if err := session.UpdateRoomSessionId(roomSessionId); err != nil {
|
||||
log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err)
|
||||
}
|
||||
session.SendMessage(message.NewErrorServerMessage(
|
||||
NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{
|
||||
Room: &RoomServerMessage{
|
||||
RoomId: room.id,
|
||||
Properties: room.properties,
|
||||
},
|
||||
}),
|
||||
))
|
||||
return
|
||||
if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) {
|
||||
// Session already is in that room, no action needed.
|
||||
roomSessionId := message.Room.SessionId
|
||||
if roomSessionId == "" {
|
||||
// TODO(jojo): Better make the session id required in the request.
|
||||
log.Printf("User did not send a room session id, assuming session %s", session.PublicId())
|
||||
roomSessionId = session.PublicId()
|
||||
}
|
||||
|
||||
if err := session.UpdateRoomSessionId(roomSessionId); err != nil {
|
||||
log.Printf("Error updating room session id for session %s: %s", session.PublicId(), err)
|
||||
}
|
||||
session.SendMessage(message.NewErrorServerMessage(
|
||||
NewErrorDetail("already_joined", "Already joined this room.", &RoomErrorDetails{
|
||||
Room: &RoomServerMessage{
|
||||
RoomId: room.id,
|
||||
Properties: room.properties,
|
||||
},
|
||||
}),
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var room BackendClientResponse
|
||||
|
@ -1430,14 +1579,14 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro
|
|||
r.AddSession(session, room.Room.Session)
|
||||
}
|
||||
|
||||
func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
|
||||
msg := message.Message
|
||||
session := client.GetSession()
|
||||
if session == nil {
|
||||
func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
|
||||
session, ok := client.GetSession().(*ClientSession)
|
||||
if session == nil || !ok {
|
||||
// Client is not connected yet.
|
||||
return
|
||||
}
|
||||
|
||||
msg := message.Message
|
||||
var recipient *ClientSession
|
||||
var subject string
|
||||
var clientData *MessageClientMessageData
|
||||
|
@ -1484,10 +1633,10 @@ func (h *Hub) processMessageMsg(client *Client, message *ClientMessage) {
|
|||
// User is stopping to share his screen. Firefox doesn't properly clean
|
||||
// up the peer connections in all cases, so make sure to stop publishing
|
||||
// in the MCU.
|
||||
go func(c *Client) {
|
||||
go func(c HandlerClient) {
|
||||
time.Sleep(cleanupScreenPublisherDelay)
|
||||
session := c.GetSession()
|
||||
if session == nil {
|
||||
session, ok := c.GetSession().(*ClientSession)
|
||||
if session == nil || !ok {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -1700,7 +1849,7 @@ func isAllowedToControl(session Session) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
|
||||
func (h *Hub) processControlMsg(client HandlerClient, message *ClientMessage) {
|
||||
msg := message.Control
|
||||
session := client.GetSession()
|
||||
if session == nil {
|
||||
|
@ -1813,10 +1962,10 @@ func (h *Hub) processControlMsg(client *Client, message *ClientMessage) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Hub) processInternalMsg(client *Client, message *ClientMessage) {
|
||||
func (h *Hub) processInternalMsg(client HandlerClient, message *ClientMessage) {
|
||||
msg := message.Internal
|
||||
session := client.GetSession()
|
||||
if session == nil {
|
||||
session, ok := client.GetSession().(*ClientSession)
|
||||
if session == nil || !ok {
|
||||
// Client is not connected yet.
|
||||
return
|
||||
} else if session.ClientType() != HelloClientTypeInternal {
|
||||
|
@ -2030,7 +2179,7 @@ func isAllowedToUpdateTransientData(session Session) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) {
|
||||
func (h *Hub) processTransientMsg(client HandlerClient, message *ClientMessage) {
|
||||
msg := message.TransientData
|
||||
session := client.GetSession()
|
||||
if session == nil {
|
||||
|
@ -2070,17 +2219,17 @@ func (h *Hub) processTransientMsg(client *Client, message *ClientMessage) {
|
|||
}
|
||||
}
|
||||
|
||||
func sendNotAllowed(session *ClientSession, message *ClientMessage, reason string) {
|
||||
func sendNotAllowed(session Session, message *ClientMessage, reason string) {
|
||||
response := message.NewErrorServerMessage(NewError("not_allowed", reason))
|
||||
session.SendMessage(response)
|
||||
}
|
||||
|
||||
func sendMcuClientNotFound(session *ClientSession, message *ClientMessage) {
|
||||
func sendMcuClientNotFound(session Session, message *ClientMessage) {
|
||||
response := message.NewErrorServerMessage(NewError("client_not_found", "No MCU client found to send message to."))
|
||||
session.SendMessage(response)
|
||||
}
|
||||
|
||||
func sendMcuProcessingFailed(session *ClientSession, message *ClientMessage) {
|
||||
func sendMcuProcessingFailed(session Session, message *ClientMessage) {
|
||||
response := message.NewErrorServerMessage(NewError("processing_failed", "Processing of the message failed, please check server logs."))
|
||||
session.SendMessage(response)
|
||||
}
|
||||
|
@ -2295,7 +2444,7 @@ func (h *Hub) sendMcuMessageResponse(session *ClientSession, mcuClient McuClient
|
|||
session.SendMessage(response_message)
|
||||
}
|
||||
|
||||
func (h *Hub) processByeMsg(client *Client, message *ClientMessage) {
|
||||
func (h *Hub) processByeMsg(client HandlerClient, message *ClientMessage) {
|
||||
client.SendByeResponse(message)
|
||||
if session := h.processUnregister(client); session != nil {
|
||||
session.Close()
|
||||
|
@ -2412,7 +2561,7 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
|
|||
}(h)
|
||||
}
|
||||
|
||||
func (h *Hub) OnLookupCountry(client *Client) string {
|
||||
func (h *Hub) OnLookupCountry(client HandlerClient) string {
|
||||
ip := net.ParseIP(client.RemoteAddr())
|
||||
if ip == nil {
|
||||
return noCountry
|
||||
|
@ -2444,14 +2593,14 @@ func (h *Hub) OnLookupCountry(client *Client) string {
|
|||
return country
|
||||
}
|
||||
|
||||
func (h *Hub) OnClosed(client *Client) {
|
||||
func (h *Hub) OnClosed(client HandlerClient) {
|
||||
h.processUnregister(client)
|
||||
}
|
||||
|
||||
func (h *Hub) OnMessageReceived(client *Client, data []byte) {
|
||||
func (h *Hub) OnMessageReceived(client HandlerClient, data []byte) {
|
||||
h.processMessage(client, data)
|
||||
}
|
||||
|
||||
func (h *Hub) OnRTTReceived(client *Client, rtt time.Duration) {
|
||||
func (h *Hub) OnRTTReceived(client HandlerClient, rtt time.Duration) {
|
||||
// Ignore
|
||||
}
|
||||
|
|
311
hub_test.go
311
hub_test.go
|
@ -274,13 +274,19 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
|
|||
h.mu.Lock()
|
||||
clients := len(h.clients)
|
||||
sessions := len(h.sessions)
|
||||
remoteSessions := len(h.remoteSessions)
|
||||
h.mu.Unlock()
|
||||
h.ru.Lock()
|
||||
rooms := len(h.rooms)
|
||||
h.ru.Unlock()
|
||||
readActive := h.readPumpActive.Load()
|
||||
writeActive := h.writePumpActive.Load()
|
||||
if clients == 0 && rooms == 0 && sessions == 0 && readActive == 0 && writeActive == 0 {
|
||||
if clients == 0 &&
|
||||
rooms == 0 &&
|
||||
sessions == 0 &&
|
||||
remoteSessions == 0 &&
|
||||
readActive == 0 &&
|
||||
writeActive == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -289,7 +295,7 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
|
|||
h.mu.Lock()
|
||||
h.ru.Lock()
|
||||
dumpGoroutines("", os.Stderr)
|
||||
t.Errorf("Error waiting for clients %+v / rooms %+v / sessions %+v / %d read / %d write to terminate: %s", h.clients, h.rooms, h.sessions, readActive, writeActive, ctx.Err())
|
||||
t.Errorf("Error waiting for clients %+v / rooms %+v / sessions %+v / remoteSessions %v / %d read / %d write to terminate: %s", h.clients, h.rooms, h.sessions, h.remoteSessions, readActive, writeActive, ctx.Err())
|
||||
h.ru.Unlock()
|
||||
h.mu.Unlock()
|
||||
return
|
||||
|
@ -1892,6 +1898,307 @@ func TestClientHelloResumeAndJoin(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func runGrpcProxyTest(t *testing.T, f func(hub1, hub2 *Hub, server1, server2 *httptest.Server)) {
|
||||
t.Helper()
|
||||
|
||||
var hub1 *Hub
|
||||
var hub2 *Hub
|
||||
var server1 *httptest.Server
|
||||
var server2 *httptest.Server
|
||||
var router1 *mux.Router
|
||||
var router2 *mux.Router
|
||||
hub1, hub2, router1, router2, server1, server2 = CreateClusteredHubsForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
|
||||
// Make sure all backends use the same server
|
||||
if server1 == nil {
|
||||
server1 = server
|
||||
} else {
|
||||
server = server1
|
||||
}
|
||||
|
||||
config, err := getTestConfig(server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.RemoveOption("backend", "allowed")
|
||||
config.RemoveOption("backend", "secret")
|
||||
config.AddOption("backend", "backends", "backend1")
|
||||
|
||||
config.AddOption("backend1", "url", server.URL)
|
||||
config.AddOption("backend1", "secret", string(testBackendSecret))
|
||||
config.AddOption("backend1", "sessionlimit", "1")
|
||||
return config, nil
|
||||
})
|
||||
|
||||
registerBackendHandlerUrl(t, router1, "/")
|
||||
registerBackendHandlerUrl(t, router2, "/")
|
||||
|
||||
f(hub1, hub2, server1, server2)
|
||||
}
|
||||
|
||||
func TestClientHelloResumeProxy(t *testing.T) {
|
||||
ensureNoGoroutinesLeak(t, func(t *testing.T) {
|
||||
runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) {
|
||||
client1 := NewTestClient(t, server1, hub1)
|
||||
defer client1.CloseWithBye()
|
||||
|
||||
if err := client1.SendHello(testDefaultUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
hello, err := client1.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if hello.Hello.UserId != testDefaultUserId {
|
||||
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
|
||||
}
|
||||
if hello.Hello.SessionId == "" {
|
||||
t.Errorf("Expected session id, got %+v", hello.Hello)
|
||||
}
|
||||
if hello.Hello.ResumeId == "" {
|
||||
t.Errorf("Expected resume id, got %+v", hello.Hello)
|
||||
}
|
||||
}
|
||||
|
||||
client1.Close()
|
||||
if err := client1.WaitForClientRemoved(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
client2 := NewTestClient(t, server2, hub2)
|
||||
defer client2.CloseWithBye()
|
||||
|
||||
if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hello2, err := client2.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
if hello2.Hello.UserId != testDefaultUserId {
|
||||
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
|
||||
}
|
||||
if hello2.Hello.SessionId != hello.Hello.SessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
|
||||
}
|
||||
if hello2.Hello.ResumeId != hello.Hello.ResumeId {
|
||||
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
|
||||
}
|
||||
}
|
||||
|
||||
// Join room by id.
|
||||
roomId := "test-room"
|
||||
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if room.Room.RoomId != roomId {
|
||||
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
|
||||
}
|
||||
|
||||
// We will receive a "joined" event.
|
||||
if err := client2.RunUntilJoined(ctx, hello.Hello); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if room := hub1.getRoom(roomId); room == nil {
|
||||
t.Fatalf("Could not find room %s", roomId)
|
||||
}
|
||||
if room := hub2.getRoom(roomId); room != nil {
|
||||
t.Fatalf("Should not have gotten room %s, got %+v", roomId, room)
|
||||
}
|
||||
|
||||
users := []map[string]interface{}{
|
||||
{
|
||||
"sessionId": "the-session-id",
|
||||
"inCall": 1,
|
||||
},
|
||||
}
|
||||
room := hub1.getRoom(roomId)
|
||||
if room == nil {
|
||||
t.Fatalf("Could not find room %s", roomId)
|
||||
}
|
||||
room.PublishUsersInCallChanged(users, users)
|
||||
if err := checkReceiveClientEvent(ctx, client2, "update", nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientHelloResumeProxy_Takeover(t *testing.T) {
|
||||
ensureNoGoroutinesLeak(t, func(t *testing.T) {
|
||||
runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) {
|
||||
client1 := NewTestClient(t, server1, hub1)
|
||||
defer client1.CloseWithBye()
|
||||
|
||||
if err := client1.SendHello(testDefaultUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
hello, err := client1.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if hello.Hello.UserId != testDefaultUserId {
|
||||
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
|
||||
}
|
||||
if hello.Hello.SessionId == "" {
|
||||
t.Errorf("Expected session id, got %+v", hello.Hello)
|
||||
}
|
||||
if hello.Hello.ResumeId == "" {
|
||||
t.Errorf("Expected resume id, got %+v", hello.Hello)
|
||||
}
|
||||
}
|
||||
|
||||
client2 := NewTestClient(t, server2, hub2)
|
||||
defer client2.CloseWithBye()
|
||||
|
||||
if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hello2, err := client2.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
if hello2.Hello.UserId != testDefaultUserId {
|
||||
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
|
||||
}
|
||||
if hello2.Hello.SessionId != hello.Hello.SessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
|
||||
}
|
||||
if hello2.Hello.ResumeId != hello.Hello.ResumeId {
|
||||
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
|
||||
}
|
||||
}
|
||||
|
||||
// The first client got disconnected with a reason in a "Bye" message.
|
||||
if msg, err := client1.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
if msg.Type != "bye" || msg.Bye == nil {
|
||||
t.Errorf("Expected bye message, got %+v", msg)
|
||||
} else if msg.Bye.Reason != "session_resumed" {
|
||||
t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
if msg, err := client1.RunUntilMessage(ctx); err == nil {
|
||||
t.Errorf("Expected error but received %+v", msg)
|
||||
} else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
|
||||
t.Errorf("Expected close error but received %+v", err)
|
||||
}
|
||||
|
||||
client3 := NewTestClient(t, server1, hub1)
|
||||
defer client3.CloseWithBye()
|
||||
|
||||
if err := client3.SendHelloResume(hello.Hello.ResumeId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hello3, err := client3.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
if hello3.Hello.UserId != testDefaultUserId {
|
||||
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
|
||||
}
|
||||
if hello3.Hello.SessionId != hello.Hello.SessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
|
||||
}
|
||||
if hello3.Hello.ResumeId != hello.Hello.ResumeId {
|
||||
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
|
||||
}
|
||||
}
|
||||
|
||||
// The second client got disconnected with a reason in a "Bye" message.
|
||||
if msg, err := client2.RunUntilMessage(ctx); err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
if msg.Type != "bye" || msg.Bye == nil {
|
||||
t.Errorf("Expected bye message, got %+v", msg)
|
||||
} else if msg.Bye.Reason != "session_resumed" {
|
||||
t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
if msg, err := client2.RunUntilMessage(ctx); err == nil {
|
||||
t.Errorf("Expected error but received %+v", msg)
|
||||
} else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
|
||||
t.Errorf("Expected close error but received %+v", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientHelloResumeProxy_Disconnect(t *testing.T) {
|
||||
ensureNoGoroutinesLeak(t, func(t *testing.T) {
|
||||
runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) {
|
||||
client1 := NewTestClient(t, server1, hub1)
|
||||
defer client1.CloseWithBye()
|
||||
|
||||
if err := client1.SendHello(testDefaultUserId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
|
||||
hello, err := client1.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
if hello.Hello.UserId != testDefaultUserId {
|
||||
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
|
||||
}
|
||||
if hello.Hello.SessionId == "" {
|
||||
t.Errorf("Expected session id, got %+v", hello.Hello)
|
||||
}
|
||||
if hello.Hello.ResumeId == "" {
|
||||
t.Errorf("Expected resume id, got %+v", hello.Hello)
|
||||
}
|
||||
}
|
||||
|
||||
client1.Close()
|
||||
if err := client1.WaitForClientRemoved(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
client2 := NewTestClient(t, server2, hub2)
|
||||
defer client2.CloseWithBye()
|
||||
|
||||
if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hello2, err := client2.RunUntilHello(ctx)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else {
|
||||
if hello2.Hello.UserId != testDefaultUserId {
|
||||
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
|
||||
}
|
||||
if hello2.Hello.SessionId != hello.Hello.SessionId {
|
||||
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
|
||||
}
|
||||
if hello2.Hello.ResumeId != hello.Hello.ResumeId {
|
||||
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
|
||||
}
|
||||
}
|
||||
|
||||
// Simulate unclean shutdown of second instance.
|
||||
hub2.rpcServer.conn.Stop()
|
||||
|
||||
if err := client2.WaitForClientRemoved(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientHelloClient(t *testing.T) {
|
||||
hub, _, _, server := CreateHubForTest(t)
|
||||
|
||||
|
|
|
@ -53,18 +53,18 @@ func (c *ProxyClient) SetSession(session *ProxySession) {
|
|||
c.session.Store(session)
|
||||
}
|
||||
|
||||
func (c *ProxyClient) OnClosed(client *signaling.Client) {
|
||||
func (c *ProxyClient) OnClosed(client signaling.HandlerClient) {
|
||||
if session := c.GetSession(); session != nil {
|
||||
session.MarkUsed()
|
||||
}
|
||||
c.proxy.clientClosed(&c.Client)
|
||||
}
|
||||
|
||||
func (c *ProxyClient) OnMessageReceived(client *signaling.Client, data []byte) {
|
||||
func (c *ProxyClient) OnMessageReceived(client signaling.HandlerClient, data []byte) {
|
||||
c.proxy.processMessage(c, data)
|
||||
}
|
||||
|
||||
func (c *ProxyClient) OnRTTReceived(client *signaling.Client, rtt time.Duration) {
|
||||
func (c *ProxyClient) OnRTTReceived(client signaling.HandlerClient, rtt time.Duration) {
|
||||
if session := c.GetSession(); session != nil {
|
||||
session.MarkUsed()
|
||||
}
|
||||
|
|
152
remotesession.go
Normal file
152
remotesession.go
Normal file
|
@ -0,0 +1,152 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2024 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package signaling
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RemoteSession struct {
|
||||
hub *Hub
|
||||
client *Client
|
||||
remoteClient *GrpcClient
|
||||
sessionId string
|
||||
|
||||
proxy atomic.Pointer[SessionProxy]
|
||||
}
|
||||
|
||||
func NewRemoteSession(hub *Hub, client *Client, remoteClient *GrpcClient, sessionId string) (*RemoteSession, error) {
|
||||
remoteSession := &RemoteSession{
|
||||
hub: hub,
|
||||
client: client,
|
||||
remoteClient: remoteClient,
|
||||
sessionId: sessionId,
|
||||
}
|
||||
|
||||
client.SetSessionId(sessionId)
|
||||
client.SetHandler(remoteSession)
|
||||
|
||||
proxy, err := remoteClient.ProxySession(context.Background(), sessionId, remoteSession)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
remoteSession.proxy.Store(proxy)
|
||||
|
||||
return remoteSession, nil
|
||||
}
|
||||
|
||||
func (s *RemoteSession) Country() string {
|
||||
return s.client.Country()
|
||||
}
|
||||
|
||||
func (s *RemoteSession) RemoteAddr() string {
|
||||
return s.client.RemoteAddr()
|
||||
}
|
||||
|
||||
func (s *RemoteSession) UserAgent() string {
|
||||
return s.client.UserAgent()
|
||||
}
|
||||
|
||||
func (s *RemoteSession) IsConnected() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *RemoteSession) Start(message *ClientMessage) error {
|
||||
return s.sendMessage(message)
|
||||
}
|
||||
|
||||
func (s *RemoteSession) OnProxyMessage(msg *ServerSessionMessage) error {
|
||||
var message *ServerMessage
|
||||
if err := json.Unmarshal(msg.Message, &message); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.client.SendMessage(message) {
|
||||
return fmt.Errorf("could not send message to client")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RemoteSession) OnProxyClose(err error) {
|
||||
if err != nil {
|
||||
log.Printf("Proxy connection for session %s to %s was closed with error: %s", s.sessionId, s.remoteClient.Target(), err)
|
||||
}
|
||||
s.Close()
|
||||
}
|
||||
|
||||
func (s *RemoteSession) SendMessage(message WritableClientMessage) bool {
|
||||
return s.sendMessage(message) == nil
|
||||
}
|
||||
|
||||
func (s *RemoteSession) sendProxyMessage(message []byte) error {
|
||||
proxy := s.proxy.Load()
|
||||
if proxy == nil {
|
||||
return errors.New("proxy already closed")
|
||||
}
|
||||
|
||||
msg := &ClientSessionMessage{
|
||||
Message: message,
|
||||
}
|
||||
return proxy.Send(msg)
|
||||
}
|
||||
|
||||
func (s *RemoteSession) sendMessage(message interface{}) error {
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.sendProxyMessage(data)
|
||||
}
|
||||
|
||||
func (s *RemoteSession) Close() {
|
||||
if proxy := s.proxy.Swap(nil); proxy != nil {
|
||||
proxy.Close()
|
||||
}
|
||||
s.hub.unregisterRemoteSession(s)
|
||||
s.client.Close()
|
||||
}
|
||||
|
||||
func (s *RemoteSession) OnLookupCountry(client HandlerClient) string {
|
||||
return s.hub.OnLookupCountry(client)
|
||||
}
|
||||
|
||||
func (s *RemoteSession) OnClosed(client HandlerClient) {
|
||||
s.Close()
|
||||
}
|
||||
|
||||
func (s *RemoteSession) OnMessageReceived(client HandlerClient, message []byte) {
|
||||
if err := s.sendProxyMessage(message); err != nil {
|
||||
log.Printf("Error sending %s to the proxy for session %s: %s", string(message), s.sessionId, err)
|
||||
s.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RemoteSession) OnRTTReceived(client HandlerClient, rtt time.Duration) {
|
||||
}
|
|
@ -86,6 +86,14 @@ func (s *DummySession) HasPermission(permission Permission) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (s *DummySession) SendError(e *Error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *DummySession) SendMessage(message *ServerMessage) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func checkSession(t *testing.T, sessions RoomSessions, sessionId string, roomSessionId string) Session {
|
||||
session := &DummySession{
|
||||
publicId: sessionId,
|
||||
|
|
|
@ -72,4 +72,7 @@ type Session interface {
|
|||
Close()
|
||||
|
||||
HasPermission(permission Permission) bool
|
||||
|
||||
SendError(e *Error) bool
|
||||
SendMessage(message *ServerMessage) bool
|
||||
}
|
||||
|
|
|
@ -311,12 +311,14 @@ func (c *TestClient) WaitForClientRemoved(ctx context.Context) error {
|
|||
for {
|
||||
found := false
|
||||
for _, client := range c.hub.clients {
|
||||
client.mu.Lock()
|
||||
conn := client.conn
|
||||
client.mu.Unlock()
|
||||
if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() {
|
||||
found = true
|
||||
break
|
||||
if cc, ok := client.(*Client); ok {
|
||||
cc.mu.Lock()
|
||||
conn := cc.conn
|
||||
cc.mu.Unlock()
|
||||
if conn != nil && conn.RemoteAddr().String() == c.localAddr.String() {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
|
@ -493,7 +495,7 @@ func (c *TestClient) SendHelloParams(url string, version string, clientType stri
|
|||
Hello: &HelloClientMessage{
|
||||
Version: version,
|
||||
Features: features,
|
||||
Auth: HelloClientMessageAuth{
|
||||
Auth: &HelloClientMessageAuth{
|
||||
Type: clientType,
|
||||
Url: url,
|
||||
Params: (*json.RawMessage)(&data),
|
||||
|
|
|
@ -51,7 +51,7 @@ type VirtualSession struct {
|
|||
options *AddSessionOptions
|
||||
}
|
||||
|
||||
func GetVirtualSessionId(session *ClientSession, sessionId string) string {
|
||||
func GetVirtualSessionId(session Session, sessionId string) string {
|
||||
return session.PublicId() + "|" + sessionId
|
||||
}
|
||||
|
||||
|
@ -163,7 +163,7 @@ func (s *VirtualSession) Close() {
|
|||
s.CloseWithFeedback(nil, nil)
|
||||
}
|
||||
|
||||
func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *ClientMessage) {
|
||||
func (s *VirtualSession) CloseWithFeedback(session Session, message *ClientMessage) {
|
||||
room := s.GetRoom()
|
||||
s.session.RemoveVirtualSession(s)
|
||||
removed := s.session.hub.removeSession(s)
|
||||
|
@ -173,7 +173,7 @@ func (s *VirtualSession) CloseWithFeedback(session *ClientSession, message *Clie
|
|||
s.session.events.UnregisterSessionListener(s.PublicId(), s.session.Backend(), s)
|
||||
}
|
||||
|
||||
func (s *VirtualSession) notifyBackendRemoved(room *Room, session *ClientSession, message *ClientMessage) {
|
||||
func (s *VirtualSession) notifyBackendRemoved(room *Room, session Session, message *ClientMessage) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.hub.backendTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
@ -321,3 +321,11 @@ func (s *VirtualSession) ProcessAsyncSessionMessage(message *AsyncMessage) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VirtualSession) SendError(e *Error) bool {
|
||||
return s.session.SendError(e)
|
||||
}
|
||||
|
||||
func (s *VirtualSession) SendMessage(message *ServerMessage) bool {
|
||||
return s.session.SendMessage(message)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue