Support resuming sessions that exist on a different Hub in the cluster.

This commit is contained in:
Joachim Bauch 2024-04-23 11:46:32 +02:00
parent 0c2cefa63a
commit 602452fa25
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
7 changed files with 966 additions and 2 deletions

View file

@ -26,6 +26,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/url"
@ -39,6 +40,7 @@ import (
"google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
status "google.golang.org/grpc/status"
)
@ -51,6 +53,8 @@ const (
)
var (
ErrNoSuchResumeId = fmt.Errorf("unknown resume id")
customResolverPrefix atomic.Uint64
)
@ -185,6 +189,26 @@ func (c *GrpcClient) GetServerId(ctx context.Context) (string, error) {
return response.GetServerId(), nil
}
func (c *GrpcClient) LookupResumeId(ctx context.Context, resumeId string) (*LookupResumeIdReply, error) {
statsGrpcClientCalls.WithLabelValues("LookupResumeId").Inc()
// TODO: Remove debug logging
log.Printf("Lookup resume id %s on %s", resumeId, c.Target())
response, err := c.impl.LookupResumeId(ctx, &LookupResumeIdRequest{
ResumeId: resumeId,
}, grpc.WaitForReady(true))
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
return nil, ErrNoSuchResumeId
} else if err != nil {
return nil, err
}
if sessionId := response.GetSessionId(); sessionId == "" {
return nil, ErrNoSuchResumeId
}
return response, nil
}
func (c *GrpcClient) LookupSessionId(ctx context.Context, roomSessionId string, disconnectReason string) (string, error) {
statsGrpcClientCalls.WithLabelValues("LookupSessionId").Inc()
// TODO: Remove debug logging
@ -258,6 +282,86 @@ func (c *GrpcClient) GetSessionCount(ctx context.Context, u *url.URL) (uint32, e
return response.GetCount(), nil
}
type ProxySessionReceiver interface {
RemoteAddr() string
Country() string
UserAgent() string
OnProxyMessage(message *ServerSessionMessage) error
OnProxyClose(err error)
}
type SessionProxy struct {
sessionId string
receiver ProxySessionReceiver
sendMu sync.Mutex
client RpcSessions_ProxySessionClient
}
func (p *SessionProxy) recvPump() {
var closeError error
defer func() {
p.receiver.OnProxyClose(closeError)
if err := p.Close(); err != nil {
log.Printf("Error closing proxy for session %s: %s", p.sessionId, err)
}
}()
for {
msg, err := p.client.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
log.Printf("Error receiving message from proxy for session %s: %s", p.sessionId, err)
closeError = err
break
}
if err := p.receiver.OnProxyMessage(msg); err != nil {
log.Printf("Error processing message %+v from proxy for session %s: %s", msg, p.sessionId, err)
}
}
}
func (p *SessionProxy) Send(message *ClientSessionMessage) error {
p.sendMu.Lock()
defer p.sendMu.Unlock()
return p.client.Send(message)
}
func (p *SessionProxy) Close() error {
p.sendMu.Lock()
defer p.sendMu.Unlock()
return p.client.CloseSend()
}
func (c *GrpcClient) ProxySession(ctx context.Context, sessionId string, receiver ProxySessionReceiver) (*SessionProxy, error) {
statsGrpcClientCalls.WithLabelValues("ProxySession").Inc()
md := metadata.Pairs(
"sessionId", sessionId,
"remoteAddr", receiver.RemoteAddr(),
"country", receiver.Country(),
"userAgent", receiver.UserAgent(),
)
client, err := c.impl.ProxySession(metadata.NewOutgoingContext(ctx, md), grpc.WaitForReady(true))
if err != nil {
return nil, err
}
proxy := &SessionProxy{
sessionId: sessionId,
receiver: receiver,
client: client,
}
go proxy.recvPump()
return proxy, nil
}
type grpcClientsList struct {
clients []*GrpcClient
entry *DnsMonitorEntry

225
grpc_remote_client.go Normal file
View file

@ -0,0 +1,225 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2024 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"sync/atomic"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
const (
grpcRemoteClientMessageQueue = 16
)
func getMD(md metadata.MD, key string) string {
if values := md.Get(key); len(values) > 0 {
return values[0]
}
return ""
}
// remoteGrpcClient is a remote client connecting from a GRPC proxy to a Hub.
type remoteGrpcClient struct {
hub *Hub
client RpcSessions_ProxySessionServer
sessionId string
remoteAddr string
country string
userAgent string
closeCtx context.Context
closeFunc context.CancelCauseFunc
session atomic.Pointer[Session]
messages chan WritableClientMessage
}
func newRemoteGrpcClient(hub *Hub, request RpcSessions_ProxySessionServer) (*remoteGrpcClient, error) {
md, found := metadata.FromIncomingContext(request.Context())
if !found {
return nil, errors.New("no metadata provided")
}
closeCtx, closeFunc := context.WithCancelCause(context.Background())
result := &remoteGrpcClient{
hub: hub,
client: request,
sessionId: getMD(md, "sessionId"),
remoteAddr: getMD(md, "remoteAddr"),
country: getMD(md, "country"),
userAgent: getMD(md, "userAgent"),
closeCtx: closeCtx,
closeFunc: closeFunc,
messages: make(chan WritableClientMessage, grpcRemoteClientMessageQueue),
}
return result, nil
}
func (c *remoteGrpcClient) readPump() {
var closeError error
defer func() {
c.closeFunc(closeError)
c.hub.OnClosed(c)
}()
for {
msg, err := c.client.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
// Connection was closed locally.
break
}
if status.Code(err) != codes.Canceled {
log.Printf("Error reading from remote client for session %s: %s", c.sessionId, err)
closeError = err
}
break
}
c.hub.OnMessageReceived(c, msg.Message)
}
}
func (c *remoteGrpcClient) RemoteAddr() string {
return c.remoteAddr
}
func (c *remoteGrpcClient) UserAgent() string {
return c.userAgent
}
func (c *remoteGrpcClient) Country() string {
return c.country
}
func (c *remoteGrpcClient) IsConnected() bool {
return true
}
func (c *remoteGrpcClient) IsAuthenticated() bool {
return c.GetSession() != nil
}
func (c *remoteGrpcClient) GetSession() Session {
session := c.session.Load()
if session == nil {
return nil
}
return *session
}
func (c *remoteGrpcClient) SetSession(session Session) {
if session == nil {
c.session.Store(nil)
} else {
c.session.Store(&session)
}
}
func (c *remoteGrpcClient) SendError(e *Error) bool {
message := &ServerMessage{
Type: "error",
Error: e,
}
return c.SendMessage(message)
}
func (c *remoteGrpcClient) SendByeResponse(message *ClientMessage) bool {
return c.SendByeResponseWithReason(message, "")
}
func (c *remoteGrpcClient) SendByeResponseWithReason(message *ClientMessage, reason string) bool {
response := &ServerMessage{
Type: "bye",
}
if message != nil {
response.Id = message.Id
}
if reason != "" {
if response.Bye == nil {
response.Bye = &ByeServerMessage{}
}
response.Bye.Reason = reason
}
return c.SendMessage(response)
}
func (c *remoteGrpcClient) SendMessage(message WritableClientMessage) bool {
if c.closeCtx.Err() != nil {
return false
}
select {
case c.messages <- message:
return true
default:
log.Printf("Message queue for remote client of session %s is full, not sending %+v", c.sessionId, message)
return false
}
}
func (c *remoteGrpcClient) Close() {
c.closeFunc(nil)
}
func (c *remoteGrpcClient) run() error {
go c.readPump()
for {
select {
case <-c.closeCtx.Done():
if err := context.Cause(c.closeCtx); err != context.Canceled {
return err
}
return nil
case msg := <-c.messages:
data, err := json.Marshal(msg)
if err != nil {
log.Printf("Error marshalling %+v for remote client for session %s: %s", msg, c.sessionId, err)
continue
}
if err := c.client.Send(&ServerSessionMessage{
Message: data,
}); err != nil {
return fmt.Errorf("error sending %+v to remote client for session %s: %w", msg, c.sessionId, err)
}
}
}
}

View file

@ -113,6 +113,20 @@ func (s *GrpcServer) Close() {
}
}
func (s *GrpcServer) LookupResumeId(ctx context.Context, request *LookupResumeIdRequest) (*LookupResumeIdReply, error) {
statsGrpcServerCalls.WithLabelValues("LookupResumeId").Inc()
// TODO: Remove debug logging
log.Printf("Lookup session for resume id %s", request.ResumeId)
session := s.hub.GetSessionByResumeId(request.ResumeId)
if session == nil {
return nil, status.Error(codes.NotFound, "no such room session id")
}
return &LookupResumeIdReply{
SessionId: session.PublicId(),
}, nil
}
func (s *GrpcServer) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) {
statsGrpcServerCalls.WithLabelValues("LookupSessionId").Inc()
// TODO: Remove debug logging
@ -216,3 +230,16 @@ func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCou
Count: uint32(backend.Len()),
}, nil
}
func (s *GrpcServer) ProxySession(request RpcSessions_ProxySessionServer) error {
statsGrpcServerCalls.WithLabelValues("ProxySession").Inc()
client, err := newRemoteGrpcClient(s.hub, request)
if err != nil {
return err
}
sid := s.hub.registerClient(client)
defer s.hub.unregisterClient(sid)
return client.run()
}

View file

@ -26,8 +26,18 @@ option go_package = "github.com/strukturag/nextcloud-spreed-signaling;signaling"
package signaling;
service RpcSessions {
rpc LookupResumeId(LookupResumeIdRequest) returns (LookupResumeIdReply) {}
rpc LookupSessionId(LookupSessionIdRequest) returns (LookupSessionIdReply) {}
rpc IsSessionInCall(IsSessionInCallRequest) returns (IsSessionInCallReply) {}
rpc ProxySession(stream ClientSessionMessage) returns (stream ServerSessionMessage) {}
}
message LookupResumeIdRequest {
string resumeId = 1;
}
message LookupResumeIdReply {
string sessionId = 1;
}
message LookupSessionIdRequest {
@ -49,3 +59,11 @@ message IsSessionInCallRequest {
message IsSessionInCallReply {
bool inCall = 1;
}
message ClientSessionMessage {
bytes message = 1;
}
message ServerSessionMessage {
bytes message = 1;
}

131
hub.go
View file

@ -155,6 +155,7 @@ type Hub struct {
anonymousSessions map[*ClientSession]time.Time
expectHelloClients map[HandlerClient]time.Time
dialoutSessions map[*ClientSession]bool
remoteSessions map[*RemoteSession]bool
backendTimeout time.Duration
backend *BackendClient
@ -343,6 +344,7 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
anonymousSessions: make(map[*ClientSession]time.Time),
expectHelloClients: make(map[HandlerClient]time.Time),
dialoutSessions: make(map[*ClientSession]bool),
remoteSessions: make(map[*RemoteSession]bool),
backendTimeout: backendTimeout,
backend: backend,
@ -584,6 +586,22 @@ func (h *Hub) GetSessionByPublicId(sessionId string) Session {
return session
}
func (h *Hub) GetSessionByResumeId(resumeId string) Session {
data := h.decodeSessionId(resumeId, privateSessionName)
if data == nil {
return nil
}
h.mu.RLock()
defer h.mu.RUnlock()
session := h.sessions[data.Sid]
if session != nil && session.PrivateId() != resumeId {
// Session was created on different server.
return nil
}
return session
}
func (h *Hub) GetDialoutSession(roomId string, backend *Backend) *ClientSession {
url := backend.Url()
@ -715,6 +733,30 @@ func (h *Hub) sendWelcome(client HandlerClient) {
client.SendMessage(h.getWelcomeMessage())
}
func (h *Hub) registerClient(client HandlerClient) uint64 {
sid := h.sid.Add(1)
for sid == 0 {
sid = h.sid.Add(1)
}
h.mu.Lock()
defer h.mu.Unlock()
h.clients[sid] = client
return sid
}
func (h *Hub) unregisterClient(sid uint64) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.clients, sid)
}
func (h *Hub) unregisterRemoteSession(session *RemoteSession) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.remoteSessions, session)
}
func (h *Hub) newSessionIdData(backend *Backend) *SessionIdData {
sid := h.sid.Add(1)
for sid == 0 {
@ -953,12 +995,97 @@ func (h *Hub) sendHelloResponse(session *ClientSession, message *ClientMessage)
return session.SendMessage(response)
}
type remoteClientInfo struct {
client *GrpcClient
response *LookupResumeIdReply
}
func (h *Hub) tryProxyResume(c HandlerClient, resumeId string, message *ClientMessage) bool {
client, ok := c.(*Client)
if !ok {
return false
}
var clients []*GrpcClient
if h.rpcClients != nil {
clients = h.rpcClients.GetClients()
}
if len(clients) == 0 {
return false
}
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer rpcCancel()
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(rpcCtx)
defer cancel()
var remoteClient atomic.Pointer[remoteClientInfo]
for _, c := range clients {
wg.Add(1)
go func(client *GrpcClient) {
defer wg.Done()
if client.IsSelf() {
return
}
response, err := client.LookupResumeId(ctx, resumeId)
if err != nil {
log.Printf("Could not lookup resume id %s on %s: %s", resumeId, client.Target(), err)
return
}
cancel()
remoteClient.CompareAndSwap(nil, &remoteClientInfo{
client: client,
response: response,
})
}(c)
}
wg.Wait()
if !client.IsConnected() {
// Client disconnected while checking message.
return false
}
info := remoteClient.Load()
if info == nil {
return false
}
rs, err := NewRemoteSession(h, client, info.client, info.response.SessionId)
if err != nil {
log.Printf("Could not create remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err)
return false
}
if err := rs.Start(message); err != nil {
rs.Close()
log.Printf("Could not start remote session %s on %s: %s", info.response.SessionId, info.client.Target(), err)
return false
}
log.Printf("Proxy session %s to %s", info.response.SessionId, info.client.Target())
h.mu.Lock()
defer h.mu.Unlock()
h.remoteSessions[rs] = true
delete(h.expectHelloClients, client)
return true
}
func (h *Hub) processHello(client HandlerClient, message *ClientMessage) {
resumeId := message.Hello.ResumeId
if resumeId != "" {
data := h.decodeSessionId(resumeId, privateSessionName)
if data == nil {
statsHubSessionResumeFailed.Inc()
if h.tryProxyResume(client, resumeId, message) {
return
}
client.SendMessage(message.NewErrorServerMessage(NoSuchSession))
return
}
@ -968,6 +1095,10 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) {
if !found || resumeId != session.PrivateId() {
h.mu.Unlock()
statsHubSessionResumeFailed.Inc()
if h.tryProxyResume(client, resumeId, message) {
return
}
client.SendMessage(message.NewErrorServerMessage(NoSuchSession))
return
}

View file

@ -274,13 +274,19 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
h.mu.Lock()
clients := len(h.clients)
sessions := len(h.sessions)
remoteSessions := len(h.remoteSessions)
h.mu.Unlock()
h.ru.Lock()
rooms := len(h.rooms)
h.ru.Unlock()
readActive := h.readPumpActive.Load()
writeActive := h.writePumpActive.Load()
if clients == 0 && rooms == 0 && sessions == 0 && readActive == 0 && writeActive == 0 {
if clients == 0 &&
rooms == 0 &&
sessions == 0 &&
remoteSessions == 0 &&
readActive == 0 &&
writeActive == 0 {
break
}
@ -289,7 +295,7 @@ func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
h.mu.Lock()
h.ru.Lock()
dumpGoroutines("", os.Stderr)
t.Errorf("Error waiting for clients %+v / rooms %+v / sessions %+v / %d read / %d write to terminate: %s", h.clients, h.rooms, h.sessions, readActive, writeActive, ctx.Err())
t.Errorf("Error waiting for clients %+v / rooms %+v / sessions %+v / remoteSessions %v / %d read / %d write to terminate: %s", h.clients, h.rooms, h.sessions, h.remoteSessions, readActive, writeActive, ctx.Err())
h.ru.Unlock()
h.mu.Unlock()
return
@ -1892,6 +1898,307 @@ func TestClientHelloResumeAndJoin(t *testing.T) {
}
}
func runGrpcProxyTest(t *testing.T, f func(hub1, hub2 *Hub, server1, server2 *httptest.Server)) {
t.Helper()
var hub1 *Hub
var hub2 *Hub
var server1 *httptest.Server
var server2 *httptest.Server
var router1 *mux.Router
var router2 *mux.Router
hub1, hub2, router1, router2, server1, server2 = CreateClusteredHubsForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
// Make sure all backends use the same server
if server1 == nil {
server1 = server
} else {
server = server1
}
config, err := getTestConfig(server)
if err != nil {
return nil, err
}
config.RemoveOption("backend", "allowed")
config.RemoveOption("backend", "secret")
config.AddOption("backend", "backends", "backend1")
config.AddOption("backend1", "url", server.URL)
config.AddOption("backend1", "secret", string(testBackendSecret))
config.AddOption("backend1", "sessionlimit", "1")
return config, nil
})
registerBackendHandlerUrl(t, router1, "/")
registerBackendHandlerUrl(t, router2, "/")
f(hub1, hub2, server1, server2)
}
func TestClientHelloResumeProxy(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) {
client1 := NewTestClient(t, server1, hub1)
defer client1.CloseWithBye()
if err := client1.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello, err := client1.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
} else {
if hello.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
if hello.Hello.ResumeId == "" {
t.Errorf("Expected resume id, got %+v", hello.Hello)
}
}
client1.Close()
if err := client1.WaitForClientRemoved(ctx); err != nil {
t.Error(err)
}
client2 := NewTestClient(t, server2, hub2)
defer client2.CloseWithBye()
if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil {
t.Fatal(err)
}
hello2, err := client2.RunUntilHello(ctx)
if err != nil {
t.Error(err)
} else {
if hello2.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
}
if hello2.Hello.SessionId != hello.Hello.SessionId {
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
}
if hello2.Hello.ResumeId != hello.Hello.ResumeId {
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
}
}
// Join room by id.
roomId := "test-room"
if room, err := client2.JoinRoom(ctx, roomId); err != nil {
t.Fatal(err)
} else if room.Room.RoomId != roomId {
t.Fatalf("Expected room %s, got %s", roomId, room.Room.RoomId)
}
// We will receive a "joined" event.
if err := client2.RunUntilJoined(ctx, hello.Hello); err != nil {
t.Error(err)
}
if room := hub1.getRoom(roomId); room == nil {
t.Fatalf("Could not find room %s", roomId)
}
if room := hub2.getRoom(roomId); room != nil {
t.Fatalf("Should not have gotten room %s, got %+v", roomId, room)
}
users := []map[string]interface{}{
{
"sessionId": "the-session-id",
"inCall": 1,
},
}
room := hub1.getRoom(roomId)
if room == nil {
t.Fatalf("Could not find room %s", roomId)
}
room.PublishUsersInCallChanged(users, users)
if err := checkReceiveClientEvent(ctx, client2, "update", nil); err != nil {
t.Error(err)
}
})
})
}
func TestClientHelloResumeProxy_Takeover(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) {
client1 := NewTestClient(t, server1, hub1)
defer client1.CloseWithBye()
if err := client1.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello, err := client1.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
} else {
if hello.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
if hello.Hello.ResumeId == "" {
t.Errorf("Expected resume id, got %+v", hello.Hello)
}
}
client2 := NewTestClient(t, server2, hub2)
defer client2.CloseWithBye()
if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil {
t.Fatal(err)
}
hello2, err := client2.RunUntilHello(ctx)
if err != nil {
t.Error(err)
} else {
if hello2.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
}
if hello2.Hello.SessionId != hello.Hello.SessionId {
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
}
if hello2.Hello.ResumeId != hello.Hello.ResumeId {
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
}
}
// The first client got disconnected with a reason in a "Bye" message.
if msg, err := client1.RunUntilMessage(ctx); err != nil {
t.Error(err)
} else {
if msg.Type != "bye" || msg.Bye == nil {
t.Errorf("Expected bye message, got %+v", msg)
} else if msg.Bye.Reason != "session_resumed" {
t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason)
}
}
if msg, err := client1.RunUntilMessage(ctx); err == nil {
t.Errorf("Expected error but received %+v", msg)
} else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
t.Errorf("Expected close error but received %+v", err)
}
client3 := NewTestClient(t, server1, hub1)
defer client3.CloseWithBye()
if err := client3.SendHelloResume(hello.Hello.ResumeId); err != nil {
t.Fatal(err)
}
hello3, err := client3.RunUntilHello(ctx)
if err != nil {
t.Error(err)
} else {
if hello3.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
}
if hello3.Hello.SessionId != hello.Hello.SessionId {
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
}
if hello3.Hello.ResumeId != hello.Hello.ResumeId {
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
}
}
// The second client got disconnected with a reason in a "Bye" message.
if msg, err := client2.RunUntilMessage(ctx); err != nil {
t.Error(err)
} else {
if msg.Type != "bye" || msg.Bye == nil {
t.Errorf("Expected bye message, got %+v", msg)
} else if msg.Bye.Reason != "session_resumed" {
t.Errorf("Expected reason \"session_resumed\", got %+v", msg.Bye.Reason)
}
}
if msg, err := client2.RunUntilMessage(ctx); err == nil {
t.Errorf("Expected error but received %+v", msg)
} else if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
t.Errorf("Expected close error but received %+v", err)
}
})
})
}
func TestClientHelloResumeProxy_Disconnect(t *testing.T) {
ensureNoGoroutinesLeak(t, func(t *testing.T) {
runGrpcProxyTest(t, func(hub1, hub2 *Hub, server1, server2 *httptest.Server) {
client1 := NewTestClient(t, server1, hub1)
defer client1.CloseWithBye()
if err := client1.SendHello(testDefaultUserId); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
hello, err := client1.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
} else {
if hello.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
if hello.Hello.ResumeId == "" {
t.Errorf("Expected resume id, got %+v", hello.Hello)
}
}
client1.Close()
if err := client1.WaitForClientRemoved(ctx); err != nil {
t.Error(err)
}
client2 := NewTestClient(t, server2, hub2)
defer client2.CloseWithBye()
if err := client2.SendHelloResume(hello.Hello.ResumeId); err != nil {
t.Fatal(err)
}
hello2, err := client2.RunUntilHello(ctx)
if err != nil {
t.Error(err)
} else {
if hello2.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello2.Hello)
}
if hello2.Hello.SessionId != hello.Hello.SessionId {
t.Errorf("Expected session id %s, got %+v", hello.Hello.SessionId, hello2.Hello)
}
if hello2.Hello.ResumeId != hello.Hello.ResumeId {
t.Errorf("Expected resume id %s, got %+v", hello.Hello.ResumeId, hello2.Hello)
}
}
// Simulate unclean shutdown of second instance.
hub2.rpcServer.conn.Stop()
if err := client2.WaitForClientRemoved(ctx); err != nil {
t.Error(err)
}
})
})
}
func TestClientHelloClient(t *testing.T) {
hub, _, _, server := CreateHubForTest(t)

152
remotesession.go Normal file
View file

@ -0,0 +1,152 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2024 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"sync/atomic"
"time"
)
type RemoteSession struct {
hub *Hub
client *Client
remoteClient *GrpcClient
sessionId string
proxy atomic.Pointer[SessionProxy]
}
func NewRemoteSession(hub *Hub, client *Client, remoteClient *GrpcClient, sessionId string) (*RemoteSession, error) {
remoteSession := &RemoteSession{
hub: hub,
client: client,
remoteClient: remoteClient,
sessionId: sessionId,
}
client.SetSessionId(sessionId)
client.SetHandler(remoteSession)
proxy, err := remoteClient.ProxySession(context.Background(), sessionId, remoteSession)
if err != nil {
return nil, err
}
remoteSession.proxy.Store(proxy)
return remoteSession, nil
}
func (s *RemoteSession) Country() string {
return s.client.Country()
}
func (s *RemoteSession) RemoteAddr() string {
return s.client.RemoteAddr()
}
func (s *RemoteSession) UserAgent() string {
return s.client.UserAgent()
}
func (s *RemoteSession) IsConnected() bool {
return true
}
func (s *RemoteSession) Start(message *ClientMessage) error {
return s.sendMessage(message)
}
func (s *RemoteSession) OnProxyMessage(msg *ServerSessionMessage) error {
var message *ServerMessage
if err := json.Unmarshal(msg.Message, &message); err != nil {
return err
}
if !s.client.SendMessage(message) {
return fmt.Errorf("could not send message to client")
}
return nil
}
func (s *RemoteSession) OnProxyClose(err error) {
if err != nil {
log.Printf("Proxy connection for session %s to %s was closed with error: %s", s.sessionId, s.remoteClient.Target(), err)
}
s.Close()
}
func (s *RemoteSession) SendMessage(message WritableClientMessage) bool {
return s.sendMessage(message) == nil
}
func (s *RemoteSession) sendProxyMessage(message []byte) error {
proxy := s.proxy.Load()
if proxy == nil {
return errors.New("proxy already closed")
}
msg := &ClientSessionMessage{
Message: message,
}
return proxy.Send(msg)
}
func (s *RemoteSession) sendMessage(message interface{}) error {
data, err := json.Marshal(message)
if err != nil {
return err
}
return s.sendProxyMessage(data)
}
func (s *RemoteSession) Close() {
if proxy := s.proxy.Swap(nil); proxy != nil {
proxy.Close()
}
s.hub.unregisterRemoteSession(s)
s.client.Close()
}
func (s *RemoteSession) OnLookupCountry(client HandlerClient) string {
return s.hub.OnLookupCountry(client)
}
func (s *RemoteSession) OnClosed(client HandlerClient) {
s.Close()
}
func (s *RemoteSession) OnMessageReceived(client HandlerClient, message []byte) {
if err := s.sendProxyMessage(message); err != nil {
log.Printf("Error sending %s to the proxy for session %s: %s", string(message), s.sessionId, err)
s.Close()
}
}
func (s *RemoteSession) OnRTTReceived(client HandlerClient, rtt time.Duration) {
}