Move grpc server code to "grpc" package.

This commit is contained in:
Joachim Bauch 2026-01-09 09:34:42 +01:00
commit be8353a54b
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
14 changed files with 1444 additions and 664 deletions

View file

@ -221,7 +221,7 @@ func main() {
}
}()
rpcServer, err := server.NewGrpcServer(stopCtx, cfg, version)
rpcServer, err := grpc.NewServer(stopCtx, cfg, version)
if err != nil {
logger.Fatalf("Could not create RPC server: %s", err)
}

View file

@ -172,3 +172,112 @@ func Test_GrpcClients_DnsDiscoveryInitialFailed(t *testing.T) {
assert.True(clients[0].ip.Equal(ip1), "Expected IP %s, got %s", ip1, clients[0].ip)
}
}
// Test_GrpcClients_EtcdInitial verifies that clients are created for all
// etcd targets that already exist when the client list initializes.
func Test_GrpcClients_EtcdInitial(t *testing.T) { // nolint:paralleltest
	testLogger := log.NewLoggerForTest(t)
	baseCtx := log.NewLoggerContext(t.Context(), testLogger)
	test.EnsureNoGoroutinesLeak(t, func(t *testing.T) {
		_, firstAddr := NewServerForTest(t)
		_, secondAddr := NewServerForTest(t)
		embedEtcd := etcdtest.NewServerForTest(t)
		// Register both targets before the clients are created so they are
		// picked up from the initial listing.
		embedEtcd.SetValue("/grpctargets/one", []byte(`{"address":"`+firstAddr+`"}`))
		embedEtcd.SetValue("/grpctargets/two", []byte(`{"address":"`+secondAddr+`"}`))
		client, _ := NewClientsWithEtcdForTest(t, embedEtcd, nil)
		waitCtx, cancel := context.WithTimeout(baseCtx, time.Second)
		defer cancel()
		require.NoError(t, client.WaitForInitialized(waitCtx))
		clients := client.GetClients()
		assert.Len(t, clients, 2, "Expected two clients, got %+v", clients)
	})
}
// Test_GrpcClients_EtcdUpdate verifies that the set of GRPC clients follows
// runtime changes to the etcd target keys: adding a key creates a client,
// deleting a key removes it, and overwriting a key replaces the client.
func Test_GrpcClients_EtcdUpdate(t *testing.T) {
	t.Parallel()
	logger := log.NewLoggerForTest(t)
	ctx := log.NewLoggerContext(t.Context(), logger)
	assert := assert.New(t)
	embedEtcd := etcdtest.NewServerForTest(t)
	client, _ := NewClientsWithEtcdForTest(t, embedEtcd, nil)
	ch := client.GetWakeupChannelForTesting()
	ctx, cancel := context.WithTimeout(ctx, testTimeout)
	defer cancel()
	// No targets are configured initially.
	assert.Empty(client.GetClients())
	// Drain the wakeup channel before each mutation so the following
	// waitForEvent observes the wakeup caused by that mutation and not a
	// stale one from an earlier step.
	test.DrainWakeupChannel(ch)
	_, addr1 := NewServerForTest(t)
	embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
	waitForEvent(ctx, t, ch)
	if clients := client.GetClients(); assert.Len(clients, 1) {
		assert.Equal(addr1, clients[0].Target())
	}
	test.DrainWakeupChannel(ch)
	_, addr2 := NewServerForTest(t)
	embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
	waitForEvent(ctx, t, ch)
	if clients := client.GetClients(); assert.Len(clients, 2) {
		assert.Equal(addr1, clients[0].Target())
		assert.Equal(addr2, clients[1].Target())
	}
	// Deleting a key must remove the corresponding client.
	test.DrainWakeupChannel(ch)
	embedEtcd.DeleteValue("/grpctargets/one")
	waitForEvent(ctx, t, ch)
	if clients := client.GetClients(); assert.Len(clients, 1) {
		assert.Equal(addr2, clients[0].Target())
	}
	// Overwriting an existing key must replace the client with one for the
	// new address (the count stays at one).
	test.DrainWakeupChannel(ch)
	_, addr3 := NewServerForTest(t)
	embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr3+"\"}"))
	waitForEvent(ctx, t, ch)
	if clients := client.GetClients(); assert.Len(clients, 1) {
		assert.Equal(addr3, clients[0].Target())
	}
}
// Test_GrpcClients_EtcdIgnoreSelf verifies that a discovered etcd target
// reporting the local server id is not added as a client, i.e. a signaling
// server never connects to itself through etcd discovery.
func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
	t.Parallel()
	logger := log.NewLoggerForTest(t)
	ctx := log.NewLoggerContext(t.Context(), logger)
	assert := assert.New(t)
	embedEtcd := etcdtest.NewServerForTest(t)
	client, _ := NewClientsWithEtcdForTest(t, embedEtcd, nil)
	ch := client.GetWakeupChannelForTesting()
	ctx, cancel := context.WithTimeout(ctx, testTimeout)
	defer cancel()
	assert.Empty(client.GetClients())
	// Drain before each mutation so waitForEvent sees the wakeup caused by
	// that mutation.
	test.DrainWakeupChannel(ch)
	_, addr1 := NewServerForTest(t)
	embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
	waitForEvent(ctx, t, ch)
	if clients := client.GetClients(); assert.Len(clients, 1) {
		assert.Equal(addr1, clients[0].Target())
	}
	test.DrainWakeupChannel(ch)
	server2, addr2 := NewServerForTest(t)
	// Make the second server report the local server id. Use the SetServerId
	// setter for consistency with NewServerForTestWithConfig instead of
	// writing the private field directly.
	server2.SetServerId(ServerId)
	embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
	waitForEvent(ctx, t, ch)
	// Wait until the self-check for the new target completed; the target
	// must have been ignored as it reports our own id.
	client.WaitForSelfCheck()
	if clients := client.GetClients(); assert.Len(clients, 1) {
		assert.Equal(addr1, clients[0].Target())
	}
	// Removing the ignored target must not change the client list.
	test.DrainWakeupChannel(ch)
	embedEtcd.DeleteValue("/grpctargets/two")
	waitForEvent(ctx, t, ch)
	if clients := client.GetClients(); assert.Len(clients, 1) {
		assert.Equal(addr1, clients[0].Target())
	}
}

View file

@ -19,7 +19,7 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package server
package grpc
import (
"context"
@ -37,35 +37,35 @@ import (
"github.com/strukturag/nextcloud-spreed-signaling/api"
"github.com/strukturag/nextcloud-spreed-signaling/config"
rpc "github.com/strukturag/nextcloud-spreed-signaling/grpc"
"github.com/strukturag/nextcloud-spreed-signaling/log"
"github.com/strukturag/nextcloud-spreed-signaling/sfu"
"github.com/strukturag/nextcloud-spreed-signaling/talk"
)
var (
ErrNoProxyMcu = errors.New("no proxy mcu")
)
func init() {
RegisterGrpcServerStats()
RegisterServerStats()
}
type GrpcServerHub interface {
GetSessionByResumeId(resumeId api.PrivateSessionId) Session
GetSessionByPublicId(sessionId api.PublicSessionId) Session
type ServerHub interface {
GetSessionIdByResumeId(resumeId api.PrivateSessionId) api.PublicSessionId
GetSessionIdByRoomSessionId(roomSessionId api.RoomSessionId) (api.PublicSessionId, error)
GetRoomForBackend(roomId string, backend *talk.Backend) *Room
IsSessionIdInCall(sessionId api.PublicSessionId, roomId string, backendUrl string) (bool, bool)
DisconnectSessionByRoomSessionId(sessionId api.PublicSessionId, roomSessionId api.RoomSessionId, reason string)
GetBackend(u *url.URL) *talk.Backend
CreateProxyToken(publisherId string) (string, error)
GetInternalSessions(roomId string, backend *talk.Backend) ([]*InternalSessionData, []*VirtualSessionData, bool)
GetTransientEntries(roomId string, backend *talk.Backend) (api.TransientDataEntries, bool)
GetPublisherIdForSessionId(ctx context.Context, sessionId api.PublicSessionId, streamType sfu.StreamType) (*GetPublisherIdReply, error)
ProxySession(request RpcSessions_ProxySessionServer) error
}
type GrpcServer struct {
rpc.UnimplementedRpcBackendServer
rpc.UnimplementedRpcInternalServer
rpc.UnimplementedRpcMcuServer
rpc.UnimplementedRpcSessionsServer
type Server struct {
UnimplementedRpcBackendServer
UnimplementedRpcInternalServer
UnimplementedRpcMcuServer
UnimplementedRpcSessionsServer
logger log.Logger
version string
@ -74,10 +74,10 @@ type GrpcServer struct {
listener net.Listener
serverId string // can be overwritten from tests
hub GrpcServerHub
hub ServerHub
}
func NewGrpcServer(ctx context.Context, cfg *goconf.ConfigFile, version string) (*GrpcServer, error) {
func NewServer(ctx context.Context, cfg *goconf.ConfigFile, version string) (*Server, error) {
var listener net.Listener
if addr, _ := config.GetStringOptionWithEnv(cfg, "grpc", "listen"); addr != "" {
var err error
@ -88,28 +88,36 @@ func NewGrpcServer(ctx context.Context, cfg *goconf.ConfigFile, version string)
}
logger := log.LoggerFromContext(ctx)
creds, err := rpc.NewReloadableCredentials(logger, cfg, true)
creds, err := NewReloadableCredentials(logger, cfg, true)
if err != nil {
return nil, err
}
conn := grpc.NewServer(grpc.Creds(creds))
result := &GrpcServer{
result := &Server{
logger: logger,
version: version,
creds: creds,
conn: conn,
listener: listener,
serverId: rpc.ServerId,
serverId: ServerId,
}
rpc.RegisterRpcBackendServer(conn, result)
rpc.RegisterRpcInternalServer(conn, result)
rpc.RegisterRpcSessionsServer(conn, result)
rpc.RegisterRpcMcuServer(conn, result)
RegisterRpcBackendServer(conn, result)
RegisterRpcInternalServer(conn, result)
RegisterRpcSessionsServer(conn, result)
RegisterRpcMcuServer(conn, result)
return result, nil
}
func (s *GrpcServer) Run() error {
func (s *Server) SetHub(hub ServerHub) {
s.hub = hub
}
func (s *Server) SetServerId(serverId string) {
s.serverId = serverId
}
func (s *Server) Run() error {
if s.listener == nil {
return nil
}
@ -121,28 +129,32 @@ type SimpleCloser interface {
Close()
}
func (s *GrpcServer) Close() {
func (s *Server) Close() {
s.conn.GracefulStop()
if cr, ok := s.creds.(SimpleCloser); ok {
cr.Close()
}
}
func (s *GrpcServer) LookupResumeId(ctx context.Context, request *rpc.LookupResumeIdRequest) (*rpc.LookupResumeIdReply, error) {
func (s *Server) CloseUnclean() {
s.conn.Stop()
}
func (s *Server) LookupResumeId(ctx context.Context, request *LookupResumeIdRequest) (*LookupResumeIdReply, error) {
statsGrpcServerCalls.WithLabelValues("LookupResumeId").Inc()
// TODO: Remove debug logging
s.logger.Printf("Lookup session for resume id %s", request.ResumeId)
session := s.hub.GetSessionByResumeId(api.PrivateSessionId(request.ResumeId))
if session == nil {
sessionId := s.hub.GetSessionIdByResumeId(api.PrivateSessionId(request.ResumeId))
if sessionId == "" {
return nil, status.Error(codes.NotFound, "no such room session id")
}
return &rpc.LookupResumeIdReply{
SessionId: string(session.PublicId()),
return &LookupResumeIdReply{
SessionId: string(sessionId),
}, nil
}
func (s *GrpcServer) LookupSessionId(ctx context.Context, request *rpc.LookupSessionIdRequest) (*rpc.LookupSessionIdReply, error) {
func (s *Server) LookupSessionId(ctx context.Context, request *LookupSessionIdRequest) (*LookupSessionIdReply, error) {
statsGrpcServerCalls.WithLabelValues("LookupSessionId").Inc()
// TODO: Remove debug logging
s.logger.Printf("Lookup session id for room session id %s", request.RoomSessionId)
@ -154,45 +166,30 @@ func (s *GrpcServer) LookupSessionId(ctx context.Context, request *rpc.LookupSes
}
if sid != "" && request.DisconnectReason != "" {
if session := s.hub.GetSessionByPublicId(api.PublicSessionId(sid)); session != nil {
s.logger.Printf("Closing session %s because same room session %s connected", session.PublicId(), request.RoomSessionId)
session.LeaveRoom(false)
switch sess := session.(type) {
case *ClientSession:
if client := sess.GetClient(); client != nil {
client.SendByeResponseWithReason(nil, "room_session_reconnected")
}
}
session.Close()
}
s.hub.DisconnectSessionByRoomSessionId(sid, api.RoomSessionId(request.RoomSessionId), request.DisconnectReason)
}
return &rpc.LookupSessionIdReply{
return &LookupSessionIdReply{
SessionId: string(sid),
}, nil
}
func (s *GrpcServer) IsSessionInCall(ctx context.Context, request *rpc.IsSessionInCallRequest) (*rpc.IsSessionInCallReply, error) {
func (s *Server) IsSessionInCall(ctx context.Context, request *IsSessionInCallRequest) (*IsSessionInCallReply, error) {
statsGrpcServerCalls.WithLabelValues("IsSessionInCall").Inc()
// TODO: Remove debug logging
s.logger.Printf("Check if session %s is in call %s on %s", request.SessionId, request.RoomId, request.BackendUrl)
session := s.hub.GetSessionByPublicId(api.PublicSessionId(request.SessionId))
if session == nil {
found, inCall := s.hub.IsSessionIdInCall(api.PublicSessionId(request.SessionId), request.GetRoomId(), request.GetBackendUrl())
if !found {
return nil, status.Error(codes.NotFound, "no such session id")
}
result := &rpc.IsSessionInCallReply{}
room := session.GetRoom()
if room == nil || room.Id() != request.GetRoomId() || !room.Backend().HasUrl(request.GetBackendUrl()) ||
(session.ClientType() != api.HelloClientTypeInternal && !room.IsSessionInCall(session)) {
// Recipient is not in a room, a different room or not in the call.
result.InCall = false
} else {
result.InCall = true
result := &IsSessionInCallReply{
InCall: inCall,
}
return result, nil
}
func (s *GrpcServer) GetInternalSessions(ctx context.Context, request *rpc.GetInternalSessionsRequest) (*rpc.GetInternalSessionsReply, error) {
func (s *Server) GetInternalSessions(ctx context.Context, request *GetInternalSessionsRequest) (*GetInternalSessionsReply, error) {
statsGrpcServerCalls.WithLabelValues("GetInternalSessions").Inc()
// TODO: Remove debug logging
s.logger.Printf("Get internal sessions from %s on %v (fallback %s)", request.RoomId, request.BackendUrls, request.BackendUrl) // nolint
@ -207,7 +204,7 @@ func (s *GrpcServer) GetInternalSessions(ctx context.Context, request *rpc.GetIn
backendUrls = []string{""}
}
result := &rpc.GetInternalSessionsReply{}
result := &GetInternalSessionsReply{}
processed := make(map[string]bool)
for _, bu := range backendUrls {
var parsed *url.URL
@ -230,81 +227,35 @@ func (s *GrpcServer) GetInternalSessions(ctx context.Context, request *rpc.GetIn
}
processed[backend.Id()] = true
room := s.hub.GetRoomForBackend(request.RoomId, backend)
if room == nil {
internalSessions, virtualSessions, found := s.hub.GetInternalSessions(request.RoomId, backend)
if !found {
return nil, status.Error(codes.NotFound, "no such room")
}
room.mu.RLock()
defer room.mu.RUnlock()
for session := range room.internalSessions {
result.InternalSessions = append(result.InternalSessions, &rpc.InternalSessionData{
SessionId: string(session.PublicId()),
InCall: uint32(session.GetInCall()),
Features: session.GetFeatures(),
})
}
for session := range room.virtualSessions {
result.VirtualSessions = append(result.VirtualSessions, &rpc.VirtualSessionData{
SessionId: string(session.PublicId()),
InCall: uint32(session.GetInCall()),
})
}
result.InternalSessions = append(result.InternalSessions, internalSessions...)
result.VirtualSessions = append(result.VirtualSessions, virtualSessions...)
}
return result, nil
}
func (s *GrpcServer) GetPublisherId(ctx context.Context, request *rpc.GetPublisherIdRequest) (*rpc.GetPublisherIdReply, error) {
func (s *Server) GetPublisherId(ctx context.Context, request *GetPublisherIdRequest) (*GetPublisherIdReply, error) {
statsGrpcServerCalls.WithLabelValues("GetPublisherId").Inc()
// TODO: Remove debug logging
s.logger.Printf("Get %s publisher id for session %s", request.StreamType, request.SessionId)
session := s.hub.GetSessionByPublicId(api.PublicSessionId(request.SessionId))
if session == nil {
return nil, status.Error(codes.NotFound, "no such session")
}
clientSession, ok := session.(*ClientSession)
if !ok {
return nil, status.Error(codes.NotFound, "no such session")
}
publisher := clientSession.GetOrWaitForPublisher(ctx, sfu.StreamType(request.StreamType))
if publisher, ok := publisher.(sfu.PublisherWithConnectionUrlAndIP); ok {
connUrl, ip := publisher.GetConnectionURL()
reply := &rpc.GetPublisherIdReply{
PublisherId: publisher.Id(),
ProxyUrl: connUrl,
}
if len(ip) > 0 {
reply.Ip = ip.String()
}
var err error
if reply.ConnectToken, err = s.hub.CreateProxyToken(""); err != nil && !errors.Is(err, ErrNoProxyMcu) {
s.logger.Printf("Error creating proxy token for connection: %s", err)
return nil, status.Error(codes.Internal, "error creating proxy connect token")
}
if reply.PublisherToken, err = s.hub.CreateProxyToken(publisher.Id()); err != nil && !errors.Is(err, ErrNoProxyMcu) {
s.logger.Printf("Error creating proxy token for publisher %s: %s", publisher.Id(), err)
return nil, status.Error(codes.Internal, "error creating proxy publisher token")
}
return reply, nil
}
return nil, status.Error(codes.NotFound, "no such publisher")
return s.hub.GetPublisherIdForSessionId(ctx, api.PublicSessionId(request.SessionId), sfu.StreamType(request.StreamType))
}
func (s *GrpcServer) GetServerId(ctx context.Context, request *rpc.GetServerIdRequest) (*rpc.GetServerIdReply, error) {
func (s *Server) GetServerId(ctx context.Context, request *GetServerIdRequest) (*GetServerIdReply, error) {
statsGrpcServerCalls.WithLabelValues("GetServerId").Inc()
return &rpc.GetServerIdReply{
return &GetServerIdReply{
ServerId: s.serverId,
Version: s.version,
}, nil
}
func (s *GrpcServer) GetTransientData(ctx context.Context, request *rpc.GetTransientDataRequest) (*rpc.GetTransientDataReply, error) {
func (s *Server) GetTransientData(ctx context.Context, request *GetTransientDataRequest) (*GetTransientDataReply, error) {
statsGrpcServerCalls.WithLabelValues("GetTransientData").Inc()
backendUrls := request.BackendUrls
@ -313,7 +264,7 @@ func (s *GrpcServer) GetTransientData(ctx context.Context, request *rpc.GetTrans
backendUrls = []string{""}
}
result := &rpc.GetTransientDataReply{}
result := &GetTransientDataReply{}
processed := make(map[string]bool)
for _, bu := range backendUrls {
var parsed *url.URL
@ -336,21 +287,18 @@ func (s *GrpcServer) GetTransientData(ctx context.Context, request *rpc.GetTrans
}
processed[backend.Id()] = true
room := s.hub.GetRoomForBackend(request.RoomId, backend)
if room == nil {
entries, found := s.hub.GetTransientEntries(request.RoomId, backend)
if !found {
return nil, status.Error(codes.NotFound, "no such room")
}
entries := room.transientData.GetEntries()
if len(entries) == 0 {
} else if len(entries) == 0 {
return nil, status.Error(codes.NotFound, "room has no transient data")
}
if result.Entries == nil {
result.Entries = make(map[string]*rpc.GrpcTransientDataEntry)
result.Entries = make(map[string]*GrpcTransientDataEntry)
}
for k, v := range entries {
e := &rpc.GrpcTransientDataEntry{}
e := &GrpcTransientDataEntry{}
var err error
if e.Value, err = json.Marshal(v.Value); err != nil {
return nil, status.Errorf(codes.Internal, "error marshalling data: %s", err)
@ -365,7 +313,7 @@ func (s *GrpcServer) GetTransientData(ctx context.Context, request *rpc.GetTrans
return result, nil
}
func (s *GrpcServer) GetSessionCount(ctx context.Context, request *rpc.GetSessionCountRequest) (*rpc.GetSessionCountReply, error) {
func (s *Server) GetSessionCount(ctx context.Context, request *GetSessionCountRequest) (*GetSessionCountReply, error) {
statsGrpcServerCalls.WithLabelValues("SessionCount").Inc()
u, err := url.Parse(request.Url)
@ -378,25 +326,13 @@ func (s *GrpcServer) GetSessionCount(ctx context.Context, request *rpc.GetSessio
return nil, status.Error(codes.NotFound, "no such backend")
}
return &rpc.GetSessionCountReply{
return &GetSessionCountReply{
Count: uint32(backend.Len()),
}, nil
}
func (s *GrpcServer) ProxySession(request rpc.RpcSessions_ProxySessionServer) error {
func (s *Server) ProxySession(request RpcSessions_ProxySessionServer) error {
statsGrpcServerCalls.WithLabelValues("ProxySession").Inc()
hub, ok := s.hub.(*Hub)
if !ok {
return status.Error(codes.Internal, "invalid hub type")
}
client, err := newRemoteGrpcClient(hub, request)
if err != nil {
return err
}
sid := hub.registerClient(client)
defer hub.unregisterClient(sid)
return client.run()
return s.hub.ProxySession(request)
}

View file

@ -19,7 +19,7 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package server
package grpc
import (
"github.com/prometheus/client_golang/prometheus"
@ -40,6 +40,6 @@ var (
}
)
func RegisterGrpcServerStats() {
func RegisterServerStats() {
metrics.RegisterAll(grpcServerStats...)
}

853
grpc/server_test.go Normal file
View file

@ -0,0 +1,853 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package grpc
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"net"
"net/url"
"path"
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
status "google.golang.org/grpc/status"
"github.com/strukturag/nextcloud-spreed-signaling/api"
"github.com/strukturag/nextcloud-spreed-signaling/geoip"
"github.com/strukturag/nextcloud-spreed-signaling/internal"
"github.com/strukturag/nextcloud-spreed-signaling/log"
"github.com/strukturag/nextcloud-spreed-signaling/sfu"
"github.com/strukturag/nextcloud-spreed-signaling/talk"
"github.com/strukturag/nextcloud-spreed-signaling/test"
)
// CertificateReloadWaiter is implemented by credentials that can report when
// their server certificate has been reloaded from disk.
type CertificateReloadWaiter interface {
	WaitForCertificateReload(ctx context.Context, counter uint64) error
}

// WaitForCertificateReload blocks until the server certificate has been
// reloaded (or ctx expires). It fails if the configured credentials do not
// support reloading.
func (s *Server) WaitForCertificateReload(ctx context.Context, counter uint64) error {
	waiter, ok := s.creds.(CertificateReloadWaiter)
	if !ok {
		return errors.New("no reloadable credentials found")
	}
	return waiter.WaitForCertificateReload(ctx, counter)
}
// CertPoolReloadWaiter is implemented by credentials that can report when
// their client CA pool has been reloaded from disk.
type CertPoolReloadWaiter interface {
	WaitForCertPoolReload(ctx context.Context, counter uint64) error
}

// WaitForCertPoolReload blocks until the client CA pool has been reloaded
// (or ctx expires). It fails if the configured credentials do not support
// reloading.
func (s *Server) WaitForCertPoolReload(ctx context.Context, counter uint64) error {
	waiter, ok := s.creds.(CertPoolReloadWaiter)
	if !ok {
		return errors.New("no reloadable credentials found")
	}
	return waiter.WaitForCertPoolReload(ctx, counter)
}
// NewServerForTestWithConfig starts a GRPC server for the given configuration
// on the first free port in a fixed local range, runs it in the background
// and registers cleanup with the test. It returns the server and its listen
// address.
func NewServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *Server, addr string) {
	logger := log.NewLoggerForTest(t)
	ctx := log.NewLoggerContext(t.Context(), logger)
	// Probe a fixed range of local ports until one is available.
	for candidate := 50000; candidate < 50100; candidate++ {
		addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(candidate))
		config.AddOption("grpc", "listen", addr)
		created, err := NewServer(ctx, config, "0.0.0")
		if test.IsErrorAddressAlreadyInUse(err) {
			// Port taken, try the next one.
			continue
		}
		require.NoError(t, err)
		server = created
		break
	}
	require.NotNil(t, server, "could not find free port")
	// Don't match with own server id by default.
	server.SetServerId("dont-match")
	go func() {
		assert.NoError(t, server.Run(), "could not start GRPC server")
	}()
	t.Cleanup(func() {
		server.Close()
	})
	return server, addr
}
// NewServerForTest starts a GRPC server with an empty default configuration.
func NewServerForTest(t *testing.T) (server *Server, addr string) {
	return NewServerForTestWithConfig(t, goconf.NewConfigFile())
}
// TestServer_ReloadCerts verifies that the GRPC server picks up a replaced
// server certificate from disk: a connection made before the replacement is
// served with the original certificate, a connection made after the reload
// is served with the updated one.
func TestServer_ReloadCerts(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	// Small key size to keep the test fast; not security relevant here.
	key, err := rsa.GenerateKey(rand.Reader, 1024)
	require.NoError(err)
	org1 := "Testing certificate"
	cert1 := internal.GenerateSelfSignedCertificateForTesting(t, org1, key)
	dir := t.TempDir()
	privkeyFile := path.Join(dir, "privkey.pem")
	pubkeyFile := path.Join(dir, "pubkey.pem")
	certFile := path.Join(dir, "cert.pem")
	require.NoError(internal.WritePrivateKey(key, privkeyFile))
	require.NoError(internal.WritePublicKey(&key.PublicKey, pubkeyFile))
	require.NoError(internal.WriteCertificate(cert1, certFile))
	config := goconf.NewConfigFile()
	config.AddOption("grpc", "servercertificate", certFile)
	config.AddOption("grpc", "serverkey", privkeyFile)
	server, addr := NewServerForTestWithConfig(t, config)
	// First connection: must present the initial certificate.
	cp1 := x509.NewCertPool()
	cp1.AddCert(cert1)
	cfg1 := &tls.Config{
		RootCAs: cp1,
	}
	conn1, err := tls.Dial("tcp", addr, cfg1)
	require.NoError(err)
	defer conn1.Close() // nolint
	state1 := conn1.ConnectionState()
	if certs := state1.PeerCertificates; assert.NotEmpty(certs) {
		if assert.NotEmpty(certs[0].Subject.Organization) {
			assert.Equal(org1, certs[0].Subject.Organization[0])
		}
	}
	// Replace the certificate on disk and wait for the server to reload it
	// before connecting again.
	org2 := "Updated certificate"
	cert2 := internal.GenerateSelfSignedCertificateForTesting(t, org2, key)
	internal.ReplaceCertificate(t, certFile, cert2)
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	require.NoError(server.WaitForCertificateReload(ctx, 0))
	// Second connection: must present the updated certificate.
	cp2 := x509.NewCertPool()
	cp2.AddCert(cert2)
	cfg2 := &tls.Config{
		RootCAs: cp2,
	}
	conn2, err := tls.Dial("tcp", addr, cfg2)
	require.NoError(err)
	defer conn2.Close() // nolint
	state2 := conn2.ConnectionState()
	if certs := state2.PeerCertificates; assert.NotEmpty(certs) {
		if assert.NotEmpty(certs[0].Subject.Organization) {
			assert.Equal(org2, certs[0].Subject.Organization[0])
		}
	}
}
// TestServer_ReloadCA verifies that the GRPC server reloads its client CA
// from disk: after the CA file is replaced, only clients presenting a
// certificate matching the new CA can complete requests.
func TestServer_ReloadCA(t *testing.T) {
	t.Parallel()
	logger := log.NewLoggerForTest(t)
	require := require.New(t)
	// Small key sizes to keep the test fast; not security relevant here.
	serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
	require.NoError(err)
	clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
	require.NoError(err)
	serverCert := internal.GenerateSelfSignedCertificateForTesting(t, "Server cert", serverKey)
	org1 := "Testing client"
	clientCert1 := internal.GenerateSelfSignedCertificateForTesting(t, org1, clientKey)
	dir := t.TempDir()
	privkeyFile := path.Join(dir, "privkey.pem")
	pubkeyFile := path.Join(dir, "pubkey.pem")
	certFile := path.Join(dir, "cert.pem")
	caFile := path.Join(dir, "ca.pem")
	require.NoError(internal.WritePrivateKey(serverKey, privkeyFile))
	require.NoError(internal.WritePublicKey(&serverKey.PublicKey, pubkeyFile))
	require.NoError(internal.WriteCertificate(serverCert, certFile))
	// The client certificate doubles as the trusted client CA.
	require.NoError(internal.WriteCertificate(clientCert1, caFile))
	config := goconf.NewConfigFile()
	config.AddOption("grpc", "servercertificate", certFile)
	config.AddOption("grpc", "serverkey", privkeyFile)
	config.AddOption("grpc", "clientca", caFile)
	server, addr := NewServerForTestWithConfig(t, config)
	pool := x509.NewCertPool()
	pool.AddCert(serverCert)
	// First client: authenticates with the initial client certificate.
	pair1, err := tls.X509KeyPair(pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: clientCert1.Raw,
	}), pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
	}))
	require.NoError(err)
	cfg1 := &tls.Config{
		RootCAs:      pool,
		Certificates: []tls.Certificate{pair1},
	}
	client1, err := NewClient(logger, addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg1)))
	require.NoError(err)
	defer client1.Close() // nolint
	ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
	defer cancel1()
	_, _, err = client1.GetServerId(ctx1)
	require.NoError(err)
	// Replace the client CA on disk and wait for the server to reload it.
	org2 := "Updated client"
	clientCert2 := internal.GenerateSelfSignedCertificateForTesting(t, org2, clientKey)
	internal.ReplaceCertificate(t, caFile, clientCert2)
	require.NoError(server.WaitForCertPoolReload(ctx1, 0))
	// Second client: authenticates with the updated certificate.
	pair2, err := tls.X509KeyPair(pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: clientCert2.Raw,
	}), pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
	}))
	require.NoError(err)
	cfg2 := &tls.Config{
		RootCAs:      pool,
		Certificates: []tls.Certificate{pair2},
	}
	client2, err := NewClient(logger, addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg2)))
	require.NoError(err)
	defer client2.Close() // nolint
	ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
	defer cancel2()
	// This will fail if the CA certificate has not been reloaded by the server.
	_, _, err = client2.GetServerId(ctx2)
	require.NoError(err)
}
// TestClients_Encryption verifies mutual-TLS between GRPC clients and the
// server: the server requires a client certificate signed by the configured
// client CA, and clients verify the server certificate against the
// configured server CA.
func TestClients_Encryption(t *testing.T) { // nolint:paralleltest
	test.EnsureNoGoroutinesLeak(t, func(t *testing.T) {
		require := require.New(t)
		// Small key sizes to keep the test fast; not security relevant here.
		serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
		require.NoError(err)
		clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
		require.NoError(err)
		serverCert := internal.GenerateSelfSignedCertificateForTesting(t, "Server cert", serverKey)
		clientCert := internal.GenerateSelfSignedCertificateForTesting(t, "Testing client", clientKey)
		dir := t.TempDir()
		serverPrivkeyFile := path.Join(dir, "server-privkey.pem")
		serverPubkeyFile := path.Join(dir, "server-pubkey.pem")
		serverCertFile := path.Join(dir, "server-cert.pem")
		require.NoError(internal.WritePrivateKey(serverKey, serverPrivkeyFile))
		require.NoError(internal.WritePublicKey(&serverKey.PublicKey, serverPubkeyFile))
		require.NoError(internal.WriteCertificate(serverCert, serverCertFile))
		clientPrivkeyFile := path.Join(dir, "client-privkey.pem")
		clientPubkeyFile := path.Join(dir, "client-pubkey.pem")
		clientCertFile := path.Join(dir, "client-cert.pem")
		require.NoError(internal.WritePrivateKey(clientKey, clientPrivkeyFile))
		require.NoError(internal.WritePublicKey(&clientKey.PublicKey, clientPubkeyFile))
		require.NoError(internal.WriteCertificate(clientCert, clientCertFile))
		// Both self-signed certificates double as the respective CAs.
		serverConfig := goconf.NewConfigFile()
		serverConfig.AddOption("grpc", "servercertificate", serverCertFile)
		serverConfig.AddOption("grpc", "serverkey", serverPrivkeyFile)
		serverConfig.AddOption("grpc", "clientca", clientCertFile)
		_, addr := NewServerForTestWithConfig(t, serverConfig)
		clientConfig := goconf.NewConfigFile()
		clientConfig.AddOption("grpc", "targets", addr)
		clientConfig.AddOption("grpc", "clientcertificate", clientCertFile)
		clientConfig.AddOption("grpc", "clientkey", clientPrivkeyFile)
		clientConfig.AddOption("grpc", "serverca", serverCertFile)
		clients, _ := NewClientsForTestWithConfig(t, clientConfig, nil, nil)
		ctx, cancel1 := context.WithTimeout(context.Background(), time.Second)
		defer cancel1()
		require.NoError(clients.WaitForInitialized(ctx))
		// Any request only succeeds if the mutual-TLS handshake succeeded.
		for _, client := range clients.GetClients() {
			_, _, err := client.GetServerId(ctx)
			require.NoError(err)
		}
	})
}
// disconnectInfo records the arguments of a single
// DisconnectSessionByRoomSessionId call for later inspection by a test.
type disconnectInfo struct {
	sessionId     api.PublicSessionId
	roomSessionId api.RoomSessionId
	reason        string
}

// testServerHub is a ServerHub stub backed by fixed test data. It serves a
// single talk backend and remembers the most recent disconnect request.
type testServerHub struct {
	t       *testing.T
	backend *talk.Backend

	// Most recent disconnect request; at most one call is expected per test.
	disconnected atomic.Pointer[disconnectInfo]
}
// newTestServerHub creates a ServerHub stub with a single configured talk
// backend serving testBackendUrl.
func newTestServerHub(t *testing.T) *testServerHub {
	t.Helper()
	testLogger := log.NewLoggerForTest(t)
	backendConfig := goconf.NewConfigFile()
	backendConfig.AddOption(testBackendId, "secret", "not-so-secret")
	backendConfig.AddOption(testBackendId, "sessionlimit", "10")
	backend, err := talk.NewBackendFromConfig(testLogger, testBackendId, backendConfig, "foo")
	require.NoError(t, err)
	parsedUrl, err := url.Parse(testBackendUrl)
	require.NoError(t, err)
	backend.AddUrl(parsedUrl)
	hub := &testServerHub{
		t:       t,
		backend: backend,
	}
	return hub
}
// Fixed identifiers and data returned by the testServerHub stub; the tests
// below assert that these values round-trip through the GRPC server.
const (
	testResumeId          = "test-resume-id"
	testSessionId         = "test-session-id"
	testRoomSessionId     = "test-room-session-id"
	testInternalSessionId = "test-internal-session-id"
	testVirtualSessionId  = "test-virtual-session-id"

	testInternalInCallFlags = 2
	testVirtualInCallFlags  = 3

	testBackendId  = "backend-1"
	testBackendUrl = "https://server.domain.invalid"
	testRoomId     = "test-room-id"

	testStreamType     = sfu.StreamTypeVideo
	testProxyUrl       = "https://proxy.domain.invalid"
	testIp             = "1.2.3.4"
	testConnectToken   = "test-connection-token"
	testPublisherToken = "test-publisher-token"

	testAddr    = "2.3.4.5"
	testCountry = geoip.Country("DE")
	testAgent   = "test-agent"
)

var (
	testFeatures = []string{"bar", "foo"}
	// NOTE(review): truncated to millisecond precision — presumably so the
	// value survives serialization round-trips unchanged; confirm.
	testExpires = time.Now().Add(time.Minute).Truncate(time.Millisecond)
	testMessage = []byte("hello world!")
)
// GetSessionIdByResumeId maps the known test resume id to the test session
// id and returns an empty id for anything else.
func (h *testServerHub) GetSessionIdByResumeId(resumeId api.PrivateSessionId) api.PublicSessionId {
	if resumeId != testResumeId {
		return ""
	}
	return testSessionId
}
// GetSessionIdByRoomSessionId maps the known test room session id to the
// test session id and fails with ErrNoSuchRoomSession otherwise.
func (h *testServerHub) GetSessionIdByRoomSessionId(roomSessionId api.RoomSessionId) (api.PublicSessionId, error) {
	if roomSessionId != testRoomSessionId {
		return "", ErrNoSuchRoomSession
	}
	return testSessionId, nil
}
// IsSessionIdInCall reports "found" only for the test room/backend pair; in
// that case only the test session id counts as being in the call.
func (h *testServerHub) IsSessionIdInCall(sessionId api.PublicSessionId, roomId string, backendUrl string) (bool, bool) {
	if roomId != testRoomId || backendUrl != testBackendUrl {
		return false, false
	}
	return sessionId == testSessionId, true
}
// DisconnectSessionByRoomSessionId records the disconnect request for later
// inspection and fails the test on a second call.
func (h *testServerHub) DisconnectSessionByRoomSessionId(sessionId api.PublicSessionId, roomSessionId api.RoomSessionId, reason string) {
	h.t.Helper()
	info := &disconnectInfo{
		sessionId:     sessionId,
		roomSessionId: roomSessionId,
		reason:        reason,
	}
	// Only a single disconnect is expected per test.
	assert.Nil(h.t, h.disconnected.Swap(info), "duplicate call")
}
// GetBackend returns the stub backend for the test backend URL and nil for
// every other (or missing) URL.
func (h *testServerHub) GetBackend(u *url.URL) *talk.Backend {
	switch {
	case u == nil:
		// No compat backend.
		return nil
	case u.String() == testBackendUrl:
		return h.backend
	default:
		return nil
	}
}
// GetInternalSessions returns one fixed internal and one fixed virtual
// session for the test room/backend pair, and "not found" otherwise.
func (h *testServerHub) GetInternalSessions(roomId string, backend *talk.Backend) ([]*InternalSessionData, []*VirtualSessionData, bool) {
	if roomId != testRoomId || backend != h.backend {
		return nil, nil, false
	}
	internalSessions := []*InternalSessionData{{
		SessionId: testInternalSessionId,
		InCall:    testInternalInCallFlags,
		Features:  testFeatures,
	}}
	virtualSessions := []*VirtualSessionData{{
		SessionId: testVirtualSessionId,
		InCall:    testVirtualInCallFlags,
	}}
	return internalSessions, virtualSessions, true
}
// GetTransientEntries returns two fixed transient data entries for the test
// room/backend pair, and "not found" otherwise.
func (h *testServerHub) GetTransientEntries(roomId string, backend *talk.Backend) (api.TransientDataEntries, bool) {
	if roomId != testRoomId || backend != h.backend {
		return nil, false
	}
	entries := api.TransientDataEntries{
		"foo": api.NewTransientDataEntryWithExpires("bar", testExpires),
		"bar": api.NewTransientDataEntry(123, 0),
	}
	return entries, true
}
// GetPublisherIdForSessionId returns the fixed publisher information for the
// test session / stream type combination. Unknown sessions and unknown
// stream types are reported as gRPC "NotFound" errors.
func (h *testServerHub) GetPublisherIdForSessionId(ctx context.Context, sessionId api.PublicSessionId, streamType sfu.StreamType) (*GetPublisherIdReply, error) {
	if sessionId != testSessionId {
		return nil, status.Error(codes.NotFound, "no such session")
	}
	if streamType != testStreamType {
		return nil, status.Error(codes.NotFound, "no such publisher")
	}
	reply := &GetPublisherIdReply{
		PublisherId:    testSessionId,
		ProxyUrl:       testProxyUrl,
		Ip:             testIp,
		ConnectToken:   testConnectToken,
		PublisherToken: testPublisherToken,
	}
	return reply, nil
}
// getMetadata returns the first value stored for "key" in the metadata, or
// an empty string if the key is not present.
func getMetadata(t *testing.T, md metadata.MD, key string) string {
	t.Helper()
	values := md.Get(key)
	if len(values) == 0 {
		return ""
	}
	return values[0]
}
// ProxySession validates the metadata of an incoming proxy request and, on
// success, sends the well-known test message to the client before returning.
// A request for an unknown session id is rejected with an "InvalidArgument"
// status so clients can verify their error handling.
func (h *testServerHub) ProxySession(request RpcSessions_ProxySessionServer) error {
	h.t.Helper()
	if md, found := metadata.FromIncomingContext(request.Context()); assert.True(h.t, found) {
		if getMetadata(h.t, md, "sessionId") != testSessionId {
			return status.Error(codes.InvalidArgument, "unknown session id")
		}
		// The sessionId is known to match here (checked by the guard above),
		// so only the remaining metadata needs to be validated.
		assert.Equal(h.t, testAddr, getMetadata(h.t, md, "remoteAddr"))
		assert.EqualValues(h.t, testCountry, getMetadata(h.t, md, "country"))
		assert.Equal(h.t, testAgent, getMetadata(h.t, md, "userAgent"))
	}
	assert.NoError(h.t, request.Send(&ServerSessionMessage{
		Message: testMessage,
	}))
	return nil
}
// TestServer_GetSessionIdByResumeId checks resume id lookups through the
// GRPC clients: unknown / empty resume ids fail with ErrNoSuchResumeId and
// the well-known test resume id resolves to the test session id.
func TestServer_GetSessionIdByResumeId(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	hub := newTestServerHub(t)
	server, addr := NewServerForTest(t)
	server.SetHub(hub)
	clients, _ := NewClientsForTest(t, addr, nil)
	ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
	defer cancel()
	require.NoError(clients.WaitForInitialized(ctx))
	for _, client := range clients.GetClients() {
		// An empty resume id is unknown.
		reply, err := client.LookupResumeId(ctx, "")
		assert.ErrorIs(err, ErrNoSuchResumeId, "expected unknown resume id, got %s", reply.GetSessionId())
		// A resume id differing only in a suffix is unknown, too.
		reply, err = client.LookupResumeId(ctx, testResumeId+"1")
		assert.ErrorIs(err, ErrNoSuchResumeId, "expected unknown resume id, got %s", reply.GetSessionId())
		// The well-known resume id resolves to the test session.
		if reply, err := client.LookupResumeId(ctx, testResumeId); assert.NoError(err) {
			assert.Equal(testSessionId, reply.SessionId)
		}
	}
}
// TestServer_LookupSessionId checks room session id lookups through the GRPC
// clients, including that looking up the known room session id with a
// disconnect reason triggers DisconnectSessionByRoomSessionId on the hub.
func TestServer_LookupSessionId(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	hub := newTestServerHub(t)
	server, addr := NewServerForTest(t)
	server.SetHub(hub)
	clients, _ := NewClientsForTest(t, addr, nil)
	ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
	defer cancel()
	require.NoError(clients.WaitForInitialized(ctx))
	for _, client := range clients.GetClients() {
		// Empty and unknown room session ids fail.
		sessionId, err := client.LookupSessionId(ctx, "", "")
		assert.ErrorIs(err, ErrNoSuchRoomSession, "expected unknown room session id, got %s", sessionId)
		sessionId, err = client.LookupSessionId(ctx, testRoomSessionId+"1", "")
		assert.ErrorIs(err, ErrNoSuchRoomSession, "expected unknown room session id, got %s", sessionId)
		// The known room session id resolves; passing a reason requests that
		// any previous session be disconnected.
		if sessionId, err := client.LookupSessionId(ctx, testRoomSessionId, "test-reason"); assert.NoError(err) {
			assert.EqualValues(testSessionId, sessionId)
		}
	}
	// The hub must have been asked exactly once to disconnect the session
	// with the reason given above.
	if disconnected := hub.disconnected.Load(); assert.NotNil(disconnected, "session was not disconnected") {
		assert.EqualValues(testSessionId, disconnected.sessionId)
		assert.EqualValues(testRoomSessionId, disconnected.roomSessionId)
		assert.Equal("test-reason", disconnected.reason)
	}
}
// TestServer_IsSessionInCall checks the "is session in call" RPC: only the
// exact combination of test session id, test room and test backend URL is
// reported as being in a call.
func TestServer_IsSessionInCall(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	hub := newTestServerHub(t)
	server, addr := NewServerForTest(t)
	server.SetHub(hub)
	clients, _ := NewClientsForTest(t, addr, nil)
	ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
	defer cancel()
	require.NoError(clients.WaitForInitialized(ctx))
	for _, client := range clients.GetClients() {
		// Unknown room.
		if inCall, err := client.IsSessionInCall(ctx, testSessionId, testRoomId+"1", testBackendUrl); assert.NoError(err) {
			assert.False(inCall)
		}
		// Unknown backend.
		if inCall, err := client.IsSessionInCall(ctx, testSessionId, testRoomId, testBackendUrl+"1"); assert.NoError(err) {
			assert.False(inCall)
		}
		// Unknown session in the known room / backend.
		if inCall, err := client.IsSessionInCall(ctx, testSessionId+"1", testRoomId, testBackendUrl); assert.NoError(err) {
			assert.False(inCall, "should not be in call")
		}
		// The known session is in the call.
		if inCall, err := client.IsSessionInCall(ctx, testSessionId, testRoomId, testBackendUrl); assert.NoError(err) {
			assert.True(inCall, "should be in call")
		}
	}
}
// TestServer_GetInternalSessions checks fetching internal / virtual sessions
// through the GRPC clients: unknown rooms or a missing backend list yield
// empty results, while the test room / backend combination returns the one
// internal and one virtual session provided by the test hub.
func TestServer_GetInternalSessions(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	hub := newTestServerHub(t)
	server, addr := NewServerForTest(t)
	server.SetHub(hub)
	clients, _ := NewClientsForTest(t, addr, nil)
	ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
	defer cancel()
	require.NoError(clients.WaitForInitialized(ctx))
	for _, client := range clients.GetClients() {
		// Unknown room.
		if internal, virtual, err := client.GetInternalSessions(ctx, testRoomId+"1", []string{testBackendUrl}); assert.NoError(err) {
			assert.Empty(internal)
			assert.Empty(virtual)
		}
		// Known room but no backends given.
		if internal, virtual, err := client.GetInternalSessions(ctx, testRoomId, nil); assert.NoError(err) {
			assert.Empty(internal)
			assert.Empty(virtual)
		}
		// Known room and backend: the sessions from the test hub are
		// returned, keyed by their session id.
		if internal, virtual, err := client.GetInternalSessions(ctx, testRoomId, []string{testBackendUrl}); assert.NoError(err) {
			if assert.Len(internal, 1) && assert.NotNil(internal[testInternalSessionId], "did not find %s in %+v", testInternalSessionId, internal) {
				assert.Equal(testInternalSessionId, internal[testInternalSessionId].SessionId)
				assert.EqualValues(testInternalInCallFlags, internal[testInternalSessionId].InCall)
				assert.Equal(testFeatures, internal[testInternalSessionId].Features)
			}
			if assert.Len(virtual, 1) && assert.NotNil(virtual[testVirtualSessionId], "did not find %s in %+v", testVirtualSessionId, virtual) {
				assert.Equal(testVirtualSessionId, virtual[testVirtualSessionId].SessionId)
				assert.EqualValues(testVirtualInCallFlags, virtual[testVirtualSessionId].InCall)
			}
		}
	}
}
// TestServer_GetPublisherId checks that the publisher information of the
// test session can be fetched through the GRPC clients and matches the
// values returned by the test hub.
func TestServer_GetPublisherId(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	hub := newTestServerHub(t)
	server, addr := NewServerForTest(t)
	server.SetHub(hub)
	clients, _ := NewClientsForTest(t, addr, nil)
	ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
	defer cancel()
	require.NoError(clients.WaitForInitialized(ctx))
	for _, client := range clients.GetClients() {
		if publisherId, proxyUrl, ip, connToken, publisherToken, err := client.GetPublisherId(ctx, testSessionId, sfu.StreamTypeVideo); assert.NoError(err) {
			assert.EqualValues(testSessionId, publisherId)
			assert.Equal(testProxyUrl, proxyUrl)
			// The IP is transferred as a string, so compare semantically.
			assert.True(net.ParseIP(testIp).Equal(ip), "expected IP %s, got %s", testIp, ip.String())
			assert.Equal(testConnectToken, connToken)
			assert.Equal(testPublisherToken, publisherToken)
		}
	}
}
// TestServer_GetTransientData checks fetching of transient room data through
// the GRPC clients: an unknown room has no entries, the test room has the
// two well-known entries (one with and one without an expiration).
func TestServer_GetTransientData(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	hub := newTestServerHub(t)
	server, addr := NewServerForTest(t)
	server.SetHub(hub)
	clients, _ := NewClientsForTest(t, addr, nil)
	ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
	defer cancel()
	require.NoError(clients.WaitForInitialized(ctx))
	for _, client := range clients.GetClients() {
		// Unknown room.
		if entries, err := client.GetTransientData(ctx, testRoomId+"1", hub.backend); assert.NoError(err) {
			assert.Empty(entries)
		}
		// Known room: both entries from the test hub must be present.
		if entries, err := client.GetTransientData(ctx, testRoomId, hub.backend); assert.NoError(err) && assert.Len(entries, 2) {
			if e := entries["foo"]; assert.NotNil(e, "did not find foo in %+v", entries) {
				assert.Equal("bar", e.Value)
				assert.Equal(testExpires, e.Expires)
			}
			if e := entries["bar"]; assert.NotNil(e, "did not find bar in %+v", entries) {
				assert.EqualValues(123, e.Value)
				assert.True(e.Expires.IsZero(), "should have no expiration, got %s", e.Expires)
			}
		}
	}
}
// testReceiver implements the client-side receiver for proxied session
// messages, recording what it receives for later inspection by tests.
type testReceiver struct {
	t *testing.T
	// received is set once the first proxied message arrives; a second
	// message fails the test.
	received atomic.Bool
	// closed is closed once OnProxyClose has been called.
	closed chan struct{}
}

// RemoteAddr returns the fixed test remote address.
func (r *testReceiver) RemoteAddr() string {
	return testAddr
}

// Country returns the fixed test country.
func (r *testReceiver) Country() geoip.Country {
	return testCountry
}

// UserAgent returns the fixed test user agent.
func (r *testReceiver) UserAgent() string {
	return testAgent
}
// OnProxyMessage is called for each message proxied to this receiver. The
// tests expect exactly one message carrying the well-known payload.
func (r *testReceiver) OnProxyMessage(message *ServerSessionMessage) error {
	assert.Equal(r.t, testMessage, message.Message)
	assert.False(r.t, r.received.Swap(true), "received additional message %v", message)
	return nil
}

// OnProxyClose is called when the proxied session is closed. A non-nil error
// must be the "unknown session id" InvalidArgument status produced by the
// test hub's ProxySession handler. Afterwards the "closed" channel is closed
// to wake up tests waiting for the proxy to finish.
func (r *testReceiver) OnProxyClose(err error) {
	if err != nil {
		if s := status.Convert(err); assert.NotNil(r.t, s, "expected status, got %+v", err) {
			assert.Equal(r.t, codes.InvalidArgument, s.Code())
			assert.Equal(r.t, "unknown session id", s.Message())
		}
	}
	close(r.closed)
}
// TestServer_ProxySession checks proxying of session messages: the
// server-side test hub sends one message which must arrive at the receiver
// before the proxy is closed.
func TestServer_ProxySession(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	hub := newTestServerHub(t)
	server, addr := NewServerForTest(t)
	server.SetHub(hub)
	clients, _ := NewClientsForTest(t, addr, nil)
	ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
	defer cancel()
	require.NoError(clients.WaitForInitialized(ctx))
	for _, client := range clients.GetClients() {
		receiver := &testReceiver{
			t:      t,
			closed: make(chan struct{}),
		}
		if proxy, err := client.ProxySession(ctx, testSessionId, receiver); assert.NoError(err) {
			t.Cleanup(func() {
				assert.NoError(proxy.Close())
			})
			assert.NotNil(proxy)
			// Wait until OnProxyClose has run, then check that the single
			// test message was delivered.
			<-receiver.closed
			assert.True(receiver.received.Load(), "should have received message")
		}
	}
}
// TestServer_ProxySessionError checks that proxying with an unknown session
// id is rejected by the server; the receiver's OnProxyClose validates the
// returned "unknown session id" status.
func TestServer_ProxySessionError(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	hub := newTestServerHub(t)
	server, addr := NewServerForTest(t)
	server.SetHub(hub)
	clients, _ := NewClientsForTest(t, addr, nil)
	ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
	defer cancel()
	require.NoError(clients.WaitForInitialized(ctx))
	for _, client := range clients.GetClients() {
		receiver := &testReceiver{
			t:      t,
			closed: make(chan struct{}),
		}
		// Use a session id the test hub does not know about.
		if proxy, err := client.ProxySession(ctx, testSessionId+"1", receiver); assert.NoError(err) {
			t.Cleanup(func() {
				assert.NoError(proxy.Close())
			})
			assert.NotNil(proxy)
			// OnProxyClose asserts on the expected error status.
			<-receiver.closed
		}
	}
}
// testSession is a minimal session implementation used to populate the test
// backend's session count.
type testSession struct{}

// PublicId returns the fixed test session id.
func (s *testSession) PublicId() api.PublicSessionId {
	return testSessionId
}

// ClientType reports a regular client session.
func (s *testSession) ClientType() api.ClientType {
	return api.HelloClientTypeClient
}
// TestServer_GetSessionCount checks the session count RPC: unknown backends
// and the (initially empty) test backend report zero sessions, and the count
// follows sessions being added to / removed from the backend.
func TestServer_GetSessionCount(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	hub := newTestServerHub(t)
	server, addr := NewServerForTest(t)
	server.SetHub(hub)
	clients, _ := NewClientsForTest(t, addr, nil)
	ctx, cancel := context.WithTimeout(t.Context(), testTimeout)
	defer cancel()
	require.NoError(clients.WaitForInitialized(ctx))
	for _, client := range clients.GetClients() {
		// Unknown backend.
		if count, err := client.GetSessionCount(ctx, testBackendUrl+"1"); assert.NoError(err) {
			assert.EqualValues(0, count)
		}
		// Known backend, no sessions yet.
		if count, err := client.GetSessionCount(ctx, testBackendUrl); assert.NoError(err) {
			assert.EqualValues(0, count)
		}
		assert.NoError(hub.backend.AddSession(&testSession{}))
		if count, err := client.GetSessionCount(ctx, testBackendUrl); assert.NoError(err) {
			assert.EqualValues(1, count)
		}
		// NOTE(review): this removes a different *testSession instance than
		// was added above — assumes talk.Backend tracks sessions by their
		// public id rather than pointer identity; verify in the backend
		// implementation.
		hub.backend.RemoveSession(&testSession{})
		if count, err := client.GetSessionCount(ctx, testBackendUrl); assert.NoError(err) {
			assert.EqualValues(0, count)
		}
	}
}

56
grpc/test/client_test.go Normal file
View file

@ -0,0 +1,56 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2026 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/strukturag/nextcloud-spreed-signaling/etcd/etcdtest"
)
// TestClientsWithEtcd checks that GRPC clients discover their targets from
// etcd: a single target registered below "/grpctargets" is picked up and its
// server id can be queried.
func TestClientsWithEtcd(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	serverId := "the-test-server-id"
	server, addr := NewServerForTest(t)
	server.SetServerId(serverId)
	etcd := etcdtest.NewServerForTest(t)
	// Register the running server as a discoverable GRPC target.
	etcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr+"\"}"))
	clients, _ := NewClientsWithEtcdForTest(t, etcd, nil)
	require.NoError(clients.WaitForInitialized(t.Context()))
	for _, client := range clients.GetClients() {
		if id, version, err := client.GetServerId(t.Context()); assert.NoError(err) {
			assert.Equal(serverId, id)
			assert.NotEmpty(version)
		}
	}
}

72
grpc/test/server.go Normal file
View file

@ -0,0 +1,72 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2026 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package test
import (
"net"
"strconv"
"testing"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/strukturag/nextcloud-spreed-signaling/grpc"
"github.com/strukturag/nextcloud-spreed-signaling/log"
"github.com/strukturag/nextcloud-spreed-signaling/test"
)
// NewServerForTestWithConfig starts a GRPC server for testing on a free
// local port in the range 50000-50099 using the given configuration. The
// "listen" option in the "grpc" section of the config is overwritten with
// the chosen address. The server runs in the background and is shut down
// automatically when the test finishes.
func NewServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *grpc.Server, addr string) {
	logger := log.NewLoggerForTest(t)
	ctx := log.NewLoggerContext(t.Context(), logger)
	// Try ports until one is free; parallel tests may occupy ports from the
	// same range.
	for port := 50000; port < 50100; port++ {
		addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
		config.AddOption("grpc", "listen", addr)
		var err error
		server, err = grpc.NewServer(ctx, config, "0.0.0")
		if test.IsErrorAddressAlreadyInUse(err) {
			continue
		}
		require.NoError(t, err)
		break
	}
	require.NotNil(t, server, "could not find free port")
	// Don't match with own server id by default.
	server.SetServerId("dont-match")
	go func() {
		assert.NoError(t, server.Run(), "could not start GRPC server")
	}()
	t.Cleanup(func() {
		server.Close()
	})
	return server, addr
}
// NewServerForTest starts a GRPC server for testing with an empty
// configuration. See NewServerForTestWithConfig for details.
func NewServerForTest(t *testing.T) (server *grpc.Server, addr string) {
	return NewServerForTestWithConfig(t, goconf.NewConfigFile())
}

51
grpc/test/server_test.go Normal file
View file

@ -0,0 +1,51 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2026 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestServer checks that a test GRPC server can be reached through directly
// configured clients and reports the server id assigned to it.
func TestServer(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	assert := assert.New(t)
	serverId := "the-test-server-id"
	server, addr := NewServerForTest(t)
	server.SetServerId(serverId)
	clients, _ := NewClientsForTest(t, addr, nil)
	require.NoError(clients.WaitForInitialized(t.Context()))
	for _, client := range clients.GetClients() {
		if id, version, err := client.GetServerId(t.Context()); assert.NoError(err) {
			assert.Equal(serverId, id)
			assert.NotEmpty(version)
		}
	}
}

View file

@ -150,8 +150,8 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g
})
nats, _ := nats.StartLocalServer(t)
grpcServer1, addr1 := NewGrpcServerForTest(t)
grpcServer2, addr2 := NewGrpcServerForTest(t)
grpcServer1, addr1 := grpctest.NewServerForTest(t)
grpcServer2, addr2 := grpctest.NewServerForTest(t)
if config1 == nil {
config1 = goconf.NewConfigFile()

View file

@ -1,157 +0,0 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package server
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/strukturag/nextcloud-spreed-signaling/etcd/etcdtest"
"github.com/strukturag/nextcloud-spreed-signaling/grpc"
grpctest "github.com/strukturag/nextcloud-spreed-signaling/grpc/test"
"github.com/strukturag/nextcloud-spreed-signaling/log"
"github.com/strukturag/nextcloud-spreed-signaling/test"
)
func waitForEvent(ctx context.Context, t *testing.T, ch <-chan struct{}) {
t.Helper()
select {
case <-ch:
return
case <-ctx.Done():
assert.Fail(t, "timeout waiting for event")
}
}
func Test_GrpcClients_EtcdInitial(t *testing.T) { // nolint:paralleltest
logger := log.NewLoggerForTest(t)
ctx := log.NewLoggerContext(t.Context(), logger)
test.EnsureNoGoroutinesLeak(t, func(t *testing.T) {
_, addr1 := NewGrpcServerForTest(t)
_, addr2 := NewGrpcServerForTest(t)
embedEtcd := etcdtest.NewServerForTest(t)
embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
client, _ := grpctest.NewClientsWithEtcdForTest(t, embedEtcd, nil)
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
require.NoError(t, client.WaitForInitialized(ctx))
clients := client.GetClients()
assert.Len(t, clients, 2, "Expected two clients, got %+v", clients)
})
}
func Test_GrpcClients_EtcdUpdate(t *testing.T) {
t.Parallel()
logger := log.NewLoggerForTest(t)
ctx := log.NewLoggerContext(t.Context(), logger)
assert := assert.New(t)
embedEtcd := etcdtest.NewServerForTest(t)
client, _ := grpctest.NewClientsWithEtcdForTest(t, embedEtcd, nil)
ch := client.GetWakeupChannelForTesting()
ctx, cancel := context.WithTimeout(ctx, testTimeout)
defer cancel()
assert.Empty(client.GetClients())
test.DrainWakeupChannel(ch)
_, addr1 := NewGrpcServerForTest(t)
embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr1, clients[0].Target())
}
test.DrainWakeupChannel(ch)
_, addr2 := NewGrpcServerForTest(t)
embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); assert.Len(clients, 2) {
assert.Equal(addr1, clients[0].Target())
assert.Equal(addr2, clients[1].Target())
}
test.DrainWakeupChannel(ch)
embedEtcd.DeleteValue("/grpctargets/one")
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr2, clients[0].Target())
}
test.DrainWakeupChannel(ch)
_, addr3 := NewGrpcServerForTest(t)
embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr3+"\"}"))
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr3, clients[0].Target())
}
}
func Test_GrpcClients_EtcdIgnoreSelf(t *testing.T) {
t.Parallel()
logger := log.NewLoggerForTest(t)
ctx := log.NewLoggerContext(t.Context(), logger)
assert := assert.New(t)
embedEtcd := etcdtest.NewServerForTest(t)
client, _ := grpctest.NewClientsWithEtcdForTest(t, embedEtcd, nil)
ch := client.GetWakeupChannelForTesting()
ctx, cancel := context.WithTimeout(ctx, testTimeout)
defer cancel()
assert.Empty(client.GetClients())
test.DrainWakeupChannel(ch)
_, addr1 := NewGrpcServerForTest(t)
embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr1, clients[0].Target())
}
test.DrainWakeupChannel(ch)
server2, addr2 := NewGrpcServerForTest(t)
server2.serverId = grpc.ServerId
embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
waitForEvent(ctx, t, ch)
client.WaitForSelfCheck()
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr1, clients[0].Target())
}
test.DrainWakeupChannel(ch)
embedEtcd.DeleteValue("/grpctargets/two")
waitForEvent(ctx, t, ch)
if clients := client.GetClients(); assert.Len(clients, 1) {
assert.Equal(addr1, clients[0].Target())
}
}

View file

@ -1,314 +0,0 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"net"
"path"
"strconv"
"testing"
"time"
"github.com/dlintw/goconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
rpc "github.com/strukturag/nextcloud-spreed-signaling/grpc"
grpctest "github.com/strukturag/nextcloud-spreed-signaling/grpc/test"
"github.com/strukturag/nextcloud-spreed-signaling/internal"
"github.com/strukturag/nextcloud-spreed-signaling/log"
"github.com/strukturag/nextcloud-spreed-signaling/test"
)
type CertificateReloadWaiter interface {
WaitForCertificateReload(ctx context.Context, counter uint64) error
}
func (s *GrpcServer) WaitForCertificateReload(ctx context.Context, counter uint64) error {
c, ok := s.creds.(CertificateReloadWaiter)
if !ok {
return errors.New("no reloadable credentials found")
}
return c.WaitForCertificateReload(ctx, counter)
}
type CertPoolReloadWaiter interface {
WaitForCertPoolReload(ctx context.Context, counter uint64) error
}
func (s *GrpcServer) WaitForCertPoolReload(ctx context.Context, counter uint64) error {
c, ok := s.creds.(CertPoolReloadWaiter)
if !ok {
return errors.New("no reloadable credentials found")
}
return c.WaitForCertPoolReload(ctx, counter)
}
func NewGrpcServerForTestWithConfig(t *testing.T, config *goconf.ConfigFile) (server *GrpcServer, addr string) {
logger := log.NewLoggerForTest(t)
ctx := log.NewLoggerContext(t.Context(), logger)
for port := 50000; port < 50100; port++ {
addr = net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
config.AddOption("grpc", "listen", addr)
var err error
server, err = NewGrpcServer(ctx, config, "0.0.0")
if test.IsErrorAddressAlreadyInUse(err) {
continue
}
require.NoError(t, err)
break
}
require.NotNil(t, server, "could not find free port")
// Don't match with own server id by default.
server.serverId = "dont-match"
go func() {
assert.NoError(t, server.Run(), "could not start GRPC server")
}()
t.Cleanup(func() {
server.Close()
})
return server, addr
}
func NewGrpcServerForTest(t *testing.T) (server *GrpcServer, addr string) {
config := goconf.NewConfigFile()
return NewGrpcServerForTestWithConfig(t, config)
}
func Test_GrpcServer_ReloadCerts(t *testing.T) {
t.Parallel()
require := require.New(t)
assert := assert.New(t)
key, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(err)
org1 := "Testing certificate"
cert1 := internal.GenerateSelfSignedCertificateForTesting(t, org1, key)
dir := t.TempDir()
privkeyFile := path.Join(dir, "privkey.pem")
pubkeyFile := path.Join(dir, "pubkey.pem")
certFile := path.Join(dir, "cert.pem")
require.NoError(internal.WritePrivateKey(key, privkeyFile))
require.NoError(internal.WritePublicKey(&key.PublicKey, pubkeyFile))
require.NoError(internal.WriteCertificate(cert1, certFile))
config := goconf.NewConfigFile()
config.AddOption("grpc", "servercertificate", certFile)
config.AddOption("grpc", "serverkey", privkeyFile)
server, addr := NewGrpcServerForTestWithConfig(t, config)
cp1 := x509.NewCertPool()
cp1.AddCert(cert1)
cfg1 := &tls.Config{
RootCAs: cp1,
}
conn1, err := tls.Dial("tcp", addr, cfg1)
require.NoError(err)
defer conn1.Close() // nolint
state1 := conn1.ConnectionState()
if certs := state1.PeerCertificates; assert.NotEmpty(certs) {
if assert.NotEmpty(certs[0].Subject.Organization) {
assert.Equal(org1, certs[0].Subject.Organization[0])
}
}
org2 := "Updated certificate"
cert2 := internal.GenerateSelfSignedCertificateForTesting(t, org2, key)
internal.ReplaceCertificate(t, certFile, cert2)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
require.NoError(server.WaitForCertificateReload(ctx, 0))
cp2 := x509.NewCertPool()
cp2.AddCert(cert2)
cfg2 := &tls.Config{
RootCAs: cp2,
}
conn2, err := tls.Dial("tcp", addr, cfg2)
require.NoError(err)
defer conn2.Close() // nolint
state2 := conn2.ConnectionState()
if certs := state2.PeerCertificates; assert.NotEmpty(certs) {
if assert.NotEmpty(certs[0].Subject.Organization) {
assert.Equal(org2, certs[0].Subject.Organization[0])
}
}
}
func Test_GrpcServer_ReloadCA(t *testing.T) {
t.Parallel()
logger := log.NewLoggerForTest(t)
require := require.New(t)
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(err)
clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(err)
serverCert := internal.GenerateSelfSignedCertificateForTesting(t, "Server cert", serverKey)
org1 := "Testing client"
clientCert1 := internal.GenerateSelfSignedCertificateForTesting(t, org1, clientKey)
dir := t.TempDir()
privkeyFile := path.Join(dir, "privkey.pem")
pubkeyFile := path.Join(dir, "pubkey.pem")
certFile := path.Join(dir, "cert.pem")
caFile := path.Join(dir, "ca.pem")
require.NoError(internal.WritePrivateKey(serverKey, privkeyFile))
require.NoError(internal.WritePublicKey(&serverKey.PublicKey, pubkeyFile))
require.NoError(internal.WriteCertificate(serverCert, certFile))
require.NoError(internal.WriteCertificate(clientCert1, caFile))
config := goconf.NewConfigFile()
config.AddOption("grpc", "servercertificate", certFile)
config.AddOption("grpc", "serverkey", privkeyFile)
config.AddOption("grpc", "clientca", caFile)
server, addr := NewGrpcServerForTestWithConfig(t, config)
pool := x509.NewCertPool()
pool.AddCert(serverCert)
pair1, err := tls.X509KeyPair(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: clientCert1.Raw,
}), pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
}))
require.NoError(err)
cfg1 := &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{pair1},
}
client1, err := rpc.NewClient(logger, addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg1)))
require.NoError(err)
defer client1.Close() // nolint
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
_, _, err = client1.GetServerId(ctx1)
require.NoError(err)
org2 := "Updated client"
clientCert2 := internal.GenerateSelfSignedCertificateForTesting(t, org2, clientKey)
internal.ReplaceCertificate(t, caFile, clientCert2)
require.NoError(server.WaitForCertPoolReload(ctx1, 0))
pair2, err := tls.X509KeyPair(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: clientCert2.Raw,
}), pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
}))
require.NoError(err)
cfg2 := &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{pair2},
}
client2, err := rpc.NewClient(logger, addr, nil, grpc.WithTransportCredentials(credentials.NewTLS(cfg2)))
require.NoError(err)
defer client2.Close() // nolint
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
defer cancel2()
// This will fail if the CA certificate has not been reloaded by the server.
_, _, err = client2.GetServerId(ctx2)
require.NoError(err)
}
func Test_GrpcClients_Encryption(t *testing.T) { // nolint:paralleltest
test.EnsureNoGoroutinesLeak(t, func(t *testing.T) {
require := require.New(t)
serverKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(err)
clientKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(err)
serverCert := internal.GenerateSelfSignedCertificateForTesting(t, "Server cert", serverKey)
clientCert := internal.GenerateSelfSignedCertificateForTesting(t, "Testing client", clientKey)
dir := t.TempDir()
serverPrivkeyFile := path.Join(dir, "server-privkey.pem")
serverPubkeyFile := path.Join(dir, "server-pubkey.pem")
serverCertFile := path.Join(dir, "server-cert.pem")
require.NoError(internal.WritePrivateKey(serverKey, serverPrivkeyFile))
require.NoError(internal.WritePublicKey(&serverKey.PublicKey, serverPubkeyFile))
require.NoError(internal.WriteCertificate(serverCert, serverCertFile))
clientPrivkeyFile := path.Join(dir, "client-privkey.pem")
clientPubkeyFile := path.Join(dir, "client-pubkey.pem")
clientCertFile := path.Join(dir, "client-cert.pem")
require.NoError(internal.WritePrivateKey(clientKey, clientPrivkeyFile))
require.NoError(internal.WritePublicKey(&clientKey.PublicKey, clientPubkeyFile))
require.NoError(internal.WriteCertificate(clientCert, clientCertFile))
serverConfig := goconf.NewConfigFile()
serverConfig.AddOption("grpc", "servercertificate", serverCertFile)
serverConfig.AddOption("grpc", "serverkey", serverPrivkeyFile)
serverConfig.AddOption("grpc", "clientca", clientCertFile)
_, addr := NewGrpcServerForTestWithConfig(t, serverConfig)
clientConfig := goconf.NewConfigFile()
clientConfig.AddOption("grpc", "targets", addr)
clientConfig.AddOption("grpc", "clientcertificate", clientCertFile)
clientConfig.AddOption("grpc", "clientkey", clientPrivkeyFile)
clientConfig.AddOption("grpc", "serverca", serverCertFile)
clients, _ := grpctest.NewClientsForTestWithConfig(t, clientConfig, nil, nil)
ctx, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
require.NoError(clients.WaitForInitialized(ctx))
for _, client := range clients.GetClients() {
_, _, err := client.GetServerId(ctx)
require.NoError(err)
}
})
}

View file

@ -50,6 +50,8 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/strukturag/nextcloud-spreed-signaling/api"
"github.com/strukturag/nextcloud-spreed-signaling/async"
@ -93,6 +95,8 @@ var (
// TooManyRequests is returned if brute force detection reports too many failed "hello" requests.
TooManyRequests = api.NewError("too_many_requests", "Too many requests.")
ErrNoProxyTokenSupported = errors.New("proxy token generation not supported")
// Maximum number of concurrent requests to a backend.
defaultMaxConcurrentRequestsPerHost = 8
@ -225,7 +229,7 @@ type Hub struct {
geoipUpdating atomic.Bool
etcdClient etcd.Client
rpcServer *GrpcServer
rpcServer *grpc.Server
rpcClients *grpc.Clients
throttler async.Throttler
@ -237,7 +241,7 @@ type Hub struct {
blockedCandidates atomic.Pointer[container.IPList]
}
func NewHub(ctx context.Context, cfg *goconf.ConfigFile, events events.AsyncEvents, rpcServer *GrpcServer, rpcClients *grpc.Clients, etcdClient etcd.Client, r *mux.Router, version string) (*Hub, error) {
func NewHub(ctx context.Context, cfg *goconf.ConfigFile, events events.AsyncEvents, rpcServer *grpc.Server, rpcClients *grpc.Clients, etcdClient etcd.Client, r *mux.Router, version string) (*Hub, error) {
logger := log.LoggerFromContext(ctx)
hashKey, _ := config.GetStringOptionWithEnv(cfg, "sessions", "hashkey")
switch len(hashKey) {
@ -468,7 +472,7 @@ func NewHub(ctx context.Context, cfg *goconf.ConfigFile, events events.AsyncEven
})
roomPing.hub = hub
if rpcServer != nil {
rpcServer.hub = hub
rpcServer.SetHub(hub)
}
hub.upgrader.CheckOrigin = hub.checkOrigin
r.HandleFunc("/spreed", func(w http.ResponseWriter, r *http.Request) {
@ -750,10 +754,72 @@ func (h *Hub) GetSessionByResumeId(resumeId api.PrivateSessionId) Session {
return session
}
// GetSessionIdByResumeId returns the public session id of the session with
// the given resume id, or an empty id if no such session is known.
func (h *Hub) GetSessionIdByResumeId(resumeId api.PrivateSessionId) api.PublicSessionId {
	if session := h.GetSessionByResumeId(resumeId); session != nil {
		return session.PublicId()
	}

	return ""
}
// GetSessionIdByRoomSessionId returns the public session id that is currently
// mapped to the given room session id, delegating to the room session store.
func (h *Hub) GetSessionIdByRoomSessionId(roomSessionId api.RoomSessionId) (api.PublicSessionId, error) {
	return h.roomSessions.GetSessionId(roomSessionId)
}
// IsSessionIdInCall reports whether the session with the given public id is
// currently in the call of the given room / backend. The second return value
// is false if no session with that id is connected to this instance.
func (h *Hub) IsSessionIdInCall(sessionId api.PublicSessionId, roomId string, backendUrl string) (bool, bool) {
	session := h.GetSessionByPublicId(sessionId)
	if session == nil {
		return false, false
	}

	room := session.GetRoom()
	if room == nil || room.Id() != roomId || !room.Backend().HasUrl(backendUrl) {
		// Session is not in a room, or in a different room / backend.
		return false, true
	}

	// Internal clients count as "in call" as soon as they joined the room.
	if session.ClientType() != api.HelloClientTypeInternal && !room.IsSessionInCall(session) {
		return false, true
	}

	return true, true
}
// GetPublisherIdForSessionId resolves the publisher of the given stream type
// for the session with the given public id and returns its id, the proxy
// connection URL/IP and (where supported) proxy connect / publisher tokens.
//
// Returns a gRPC "NotFound" error if the session does not exist, is not a
// client session, or its publisher does not expose connection information,
// and "Internal" if token creation fails for a reason other than tokens not
// being supported.
func (h *Hub) GetPublisherIdForSessionId(ctx context.Context, sessionId api.PublicSessionId, streamType sfu.StreamType) (*grpc.GetPublisherIdReply, error) {
	session := h.GetSessionByPublicId(sessionId)
	if session == nil {
		return nil, status.Error(codes.NotFound, "no such session")
	}

	// Only client sessions can have publishers.
	clientSession, ok := session.(*ClientSession)
	if !ok {
		return nil, status.Error(codes.NotFound, "no such session")
	}

	// May block until a publisher is available or "ctx" is done.
	publisher := clientSession.GetOrWaitForPublisher(ctx, streamType)
	if publisher, ok := publisher.(sfu.PublisherWithConnectionUrlAndIP); ok {
		connUrl, ip := publisher.GetConnectionURL()
		reply := &grpc.GetPublisherIdReply{
			PublisherId: publisher.Id(),
			ProxyUrl:    connUrl,
		}
		if len(ip) > 0 {
			reply.Ip = ip.String()
		}

		// ErrNoProxyTokenSupported is not treated as an error; the token
		// fields simply stay empty in that case.
		var err error
		if reply.ConnectToken, err = h.CreateProxyToken(""); err != nil && !errors.Is(err, ErrNoProxyTokenSupported) {
			h.logger.Printf("Error creating proxy token for connection: %s", err)
			return nil, status.Error(codes.Internal, "error creating proxy connect token")
		}
		if reply.PublisherToken, err = h.CreateProxyToken(publisher.Id()); err != nil && !errors.Is(err, ErrNoProxyTokenSupported) {
			h.logger.Printf("Error creating proxy token for publisher %s: %s", publisher.Id(), err)
			return nil, status.Error(codes.Internal, "error creating proxy publisher token")
		}
		return reply, nil
	}

	return nil, status.Error(codes.NotFound, "no such publisher")
}
func (h *Hub) GetDialoutSessions(roomId string, backend *talk.Backend) (result []*ClientSession) {
h.mu.RLock()
defer h.mu.RUnlock()
@ -780,7 +846,7 @@ func (h *Hub) GetBackend(u *url.URL) *talk.Backend {
func (h *Hub) CreateProxyToken(publisherId string) (string, error) {
withToken, ok := h.mcu.(sfu.WithToken)
if !ok {
return "", ErrNoProxyMcu
return "", ErrNoProxyTokenSupported
}
return withToken.CreateToken(publisherId)
@ -1648,11 +1714,25 @@ func (h *Hub) disconnectByRoomSessionId(ctx context.Context, roomSessionId api.R
}
h.logger.Printf("Closing session %s because same room session %s connected", session.PublicId(), roomSessionId)
h.disconnectSessionWithReason(session, "room_session_reconnected")
}
// DisconnectSessionByRoomSessionId closes the session with the given public
// id (if connected to this instance), passing the given reason to the client.
func (h *Hub) DisconnectSessionByRoomSessionId(sessionId api.PublicSessionId, roomSessionId api.RoomSessionId, reason string) {
	if session := h.GetSessionByPublicId(sessionId); session != nil {
		h.logger.Printf("Closing session %s because same room session %s connected", session.PublicId(), roomSessionId)
		h.disconnectSessionWithReason(session, reason)
	}
}
func (h *Hub) disconnectSessionWithReason(session Session, reason string) {
session.LeaveRoom(false)
switch sess := session.(type) {
case *ClientSession:
if client := sess.GetClient(); client != nil {
client.SendByeResponseWithReason(nil, "room_session_reconnected")
client.SendByeResponseWithReason(nil, reason)
}
}
session.Close()
@ -1981,6 +2061,45 @@ func (h *Hub) GetRoomForBackend(id string, backend *talk.Backend) *Room {
return h.rooms[internalRoomId]
}
// GetInternalSessions returns data on the internal and virtual sessions that
// are currently in the given room. The third return value is false if the
// room is not hosted on this instance.
func (h *Hub) GetInternalSessions(roomId string, backend *talk.Backend) ([]*grpc.InternalSessionData, []*grpc.VirtualSessionData, bool) {
	room := h.GetRoomForBackend(roomId, backend)
	if room == nil {
		return nil, nil, false
	}

	room.mu.RLock()
	defer room.mu.RUnlock()

	var internal []*grpc.InternalSessionData
	for s := range room.internalSessions {
		data := &grpc.InternalSessionData{
			SessionId: string(s.PublicId()),
			InCall:    uint32(s.GetInCall()),
			Features:  s.GetFeatures(),
		}
		internal = append(internal, data)
	}

	var virtual []*grpc.VirtualSessionData
	for s := range room.virtualSessions {
		virtual = append(virtual, &grpc.VirtualSessionData{
			SessionId: string(s.PublicId()),
			InCall:    uint32(s.GetInCall()),
		})
	}

	return internal, virtual, true
}
// GetTransientEntries returns the transient data entries of the given room,
// or false if the room is not hosted on this instance.
func (h *Hub) GetTransientEntries(roomId string, backend *talk.Backend) (api.TransientDataEntries, bool) {
	if room := h.GetRoomForBackend(roomId, backend); room != nil {
		return room.transientData.GetEntries(), true
	}

	return nil, false
}
func (h *Hub) removeRoom(room *Room) {
internalRoomId := getRoomIdForBackend(room.Id(), room.Backend())
h.ru.Lock()
@ -3120,6 +3239,18 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
client.ReadPump()
}
// ProxySession serves an incoming "ProxySession" gRPC stream: the remote
// session is registered with the hub as a regular client for the lifetime of
// the stream and unregistered when the stream ends.
func (h *Hub) ProxySession(request grpc.RpcSessions_ProxySessionServer) error {
	client, err := newRemoteGrpcClient(h, request)
	if err != nil {
		return err
	}

	// Ensure the client is removed again once "run" returns.
	sid := h.registerClient(client)
	defer h.unregisterClient(sid)

	return client.run()
}
func (h *Hub) LookupCountry(addr string) geoip.Country {
ip := net.ParseIP(addr)
if ip == nil {

View file

@ -34,11 +34,15 @@ import (
"github.com/stretchr/testify/require"
"github.com/strukturag/nextcloud-spreed-signaling/api"
"github.com/strukturag/nextcloud-spreed-signaling/etcd/etcdtest"
"github.com/strukturag/nextcloud-spreed-signaling/grpc"
grpctest "github.com/strukturag/nextcloud-spreed-signaling/grpc/test"
"github.com/strukturag/nextcloud-spreed-signaling/sfu"
"github.com/strukturag/nextcloud-spreed-signaling/sfu/mock"
proxytest "github.com/strukturag/nextcloud-spreed-signaling/sfu/proxy/test"
"github.com/strukturag/nextcloud-spreed-signaling/sfu/proxy/testserver"
"github.com/strukturag/nextcloud-spreed-signaling/talk"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type mockGrpcServerHub struct {
@ -56,6 +60,13 @@ func (h *mockGrpcServerHub) setProxy(t *testing.T, proxy sfu.SFU) {
h.proxy.Store(&wt)
}
// getSession returns the registered session with the given public id, or nil
// if no such session was added.
func (h *mockGrpcServerHub) getSession(sessionId api.PublicSessionId) Session {
	h.sessionsLock.Lock()
	session := h.sessionByPublicId[sessionId]
	h.sessionsLock.Unlock()

	return session
}
func (h *mockGrpcServerHub) addSession(session *ClientSession) {
h.sessionsLock.Lock()
defer h.sessionsLock.Unlock()
@ -71,35 +82,67 @@ func (h *mockGrpcServerHub) removeSession(session *ClientSession) {
delete(h.sessionByPublicId, session.PublicId())
}
// GetSessionByResumeId implements the hub interface for tests; resuming
// sessions is not supported by this mock.
func (h *mockGrpcServerHub) GetSessionByResumeId(resumeId api.PrivateSessionId) Session {
	return nil
}
func (h *mockGrpcServerHub) GetSessionByPublicId(sessionId api.PublicSessionId) Session {
h.sessionsLock.Lock()
defer h.sessionsLock.Unlock()
return h.sessionByPublicId[sessionId]
// GetSessionIdByResumeId implements the hub interface for tests; resuming
// sessions is not supported by this mock.
func (h *mockGrpcServerHub) GetSessionIdByResumeId(resumeId api.PrivateSessionId) api.PublicSessionId {
	return ""
}
// GetSessionIdByRoomSessionId implements the hub interface for tests; the
// mock doesn't track room sessions.
func (h *mockGrpcServerHub) GetSessionIdByRoomSessionId(roomSessionId api.RoomSessionId) (api.PublicSessionId, error) {
	return "", nil
}
// IsSessionIdInCall implements the hub interface for tests; the mock doesn't
// track rooms or calls.
func (h *mockGrpcServerHub) IsSessionIdInCall(sessionId api.PublicSessionId, roomId string, backendUrl string) (bool, bool) {
	return false, false
}
// DisconnectSessionByRoomSessionId implements the hub interface for tests;
// the mock doesn't track room sessions, so this is a no-op.
func (h *mockGrpcServerHub) DisconnectSessionByRoomSessionId(sessionId api.PublicSessionId, roomSessionId api.RoomSessionId, reason string) {
}
// GetBackend implements the hub interface for tests; no backends are
// configured in this mock.
func (h *mockGrpcServerHub) GetBackend(u *url.URL) *talk.Backend {
	return nil
}
func (h *mockGrpcServerHub) GetRoomForBackend(roomId string, backend *talk.Backend) *Room {
return nil
// GetInternalSessions implements the hub interface for tests; the mock
// doesn't track rooms, so no sessions are ever returned.
func (h *mockGrpcServerHub) GetInternalSessions(roomId string, backend *talk.Backend) ([]*grpc.InternalSessionData, []*grpc.VirtualSessionData, bool) {
	return nil, nil, false
}
func (h *mockGrpcServerHub) CreateProxyToken(publisherId string) (string, error) {
proxy := h.proxy.Load()
if proxy == nil {
return "", errors.New("not a proxy mcu")
// GetTransientEntries implements the hub interface for tests; the mock
// doesn't track rooms or transient data.
func (h *mockGrpcServerHub) GetTransientEntries(roomId string, backend *talk.Backend) (api.TransientDataEntries, bool) {
	return nil, false
}
// GetPublisherIdForSessionId implements the hub interface for tests: looks up
// the session in the mock's session map and returns the connection details of
// its publisher for the given stream type. If a proxy SFU was configured via
// "setProxy", connect / publisher tokens are created with it; token creation
// errors are ignored (best effort in tests).
//
// Returns a gRPC "NotFound" error if the session does not exist, is not a
// client session, or its publisher does not expose connection information.
func (h *mockGrpcServerHub) GetPublisherIdForSessionId(ctx context.Context, sessionId api.PublicSessionId, streamType sfu.StreamType) (*grpc.GetPublisherIdReply, error) {
	// Fix: removed a stray leftover line ("return (*proxy).CreateToken(publisherId)")
	// that referenced an undefined "proxy" between statements and broke compilation.
	session := h.getSession(sessionId)
	if session == nil {
		return nil, status.Error(codes.NotFound, "no such session")
	}

	clientSession, ok := session.(*ClientSession)
	if !ok {
		return nil, status.Error(codes.NotFound, "no such session")
	}

	// May block until a publisher is available or "ctx" is done.
	publisher := clientSession.GetOrWaitForPublisher(ctx, streamType)
	if publisher, ok := publisher.(sfu.PublisherWithConnectionUrlAndIP); ok {
		connUrl, ip := publisher.GetConnectionURL()
		reply := &grpc.GetPublisherIdReply{
			PublisherId: publisher.Id(),
			ProxyUrl:    connUrl,
		}
		if len(ip) > 0 {
			reply.Ip = ip.String()
		}
		if proxy := h.proxy.Load(); proxy != nil {
			reply.ConnectToken, _ = (*proxy).CreateToken("")
			reply.PublisherToken, _ = (*proxy).CreateToken(publisher.Id())
		}
		return reply, nil
	}

	return nil, status.Error(codes.NotFound, "no such publisher")
}
// ProxySession implements the hub interface for tests; proxying sessions is
// not supported by this mock.
func (h *mockGrpcServerHub) ProxySession(request grpc.RpcSessions_ProxySessionServer) error {
	return errors.New("not implemented")
}
func Test_ProxyRemotePublisher(t *testing.T) {
@ -107,13 +150,13 @@ func Test_ProxyRemotePublisher(t *testing.T) {
embedEtcd := etcdtest.NewServerForTest(t)
grpcServer1, addr1 := NewGrpcServerForTest(t)
grpcServer2, addr2 := NewGrpcServerForTest(t)
grpcServer1, addr1 := grpctest.NewServerForTest(t)
grpcServer2, addr2 := grpctest.NewServerForTest(t)
hub1 := &mockGrpcServerHub{}
hub2 := &mockGrpcServerHub{}
grpcServer1.hub = hub1
grpcServer2.hub = hub2
grpcServer1.SetHub(hub1)
grpcServer2.SetHub(hub2)
embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
@ -178,16 +221,16 @@ func Test_ProxyMultipleRemotePublisher(t *testing.T) {
embedEtcd := etcdtest.NewServerForTest(t)
grpcServer1, addr1 := NewGrpcServerForTest(t)
grpcServer2, addr2 := NewGrpcServerForTest(t)
grpcServer3, addr3 := NewGrpcServerForTest(t)
grpcServer1, addr1 := grpctest.NewServerForTest(t)
grpcServer2, addr2 := grpctest.NewServerForTest(t)
grpcServer3, addr3 := grpctest.NewServerForTest(t)
hub1 := &mockGrpcServerHub{}
hub2 := &mockGrpcServerHub{}
hub3 := &mockGrpcServerHub{}
grpcServer1.hub = hub1
grpcServer2.hub = hub2
grpcServer3.hub = hub3
grpcServer1.SetHub(hub1)
grpcServer2.SetHub(hub2)
grpcServer3.SetHub(hub3)
embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
@ -272,13 +315,13 @@ func Test_ProxyRemotePublisherWait(t *testing.T) {
embedEtcd := etcdtest.NewServerForTest(t)
grpcServer1, addr1 := NewGrpcServerForTest(t)
grpcServer2, addr2 := NewGrpcServerForTest(t)
grpcServer1, addr1 := grpctest.NewServerForTest(t)
grpcServer2, addr2 := grpctest.NewServerForTest(t)
hub1 := &mockGrpcServerHub{}
hub2 := &mockGrpcServerHub{}
grpcServer1.hub = hub1
grpcServer2.hub = hub2
grpcServer1.SetHub(hub1)
grpcServer2.SetHub(hub2)
embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
@ -360,13 +403,13 @@ func Test_ProxyRemotePublisherTemporary(t *testing.T) {
assert := assert.New(t)
embedEtcd := etcdtest.NewServerForTest(t)
grpcServer1, addr1 := NewGrpcServerForTest(t)
grpcServer2, addr2 := NewGrpcServerForTest(t)
grpcServer1, addr1 := grpctest.NewServerForTest(t)
grpcServer2, addr2 := grpctest.NewServerForTest(t)
hub1 := &mockGrpcServerHub{}
hub2 := &mockGrpcServerHub{}
grpcServer1.hub = hub1
grpcServer2.hub = hub2
grpcServer1.SetHub(hub1)
grpcServer2.SetHub(hub2)
embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
@ -465,13 +508,13 @@ func Test_ProxyConnectToken(t *testing.T) {
embedEtcd := etcdtest.NewServerForTest(t)
grpcServer1, addr1 := NewGrpcServerForTest(t)
grpcServer2, addr2 := NewGrpcServerForTest(t)
grpcServer1, addr1 := grpctest.NewServerForTest(t)
grpcServer2, addr2 := grpctest.NewServerForTest(t)
hub1 := &mockGrpcServerHub{}
hub2 := &mockGrpcServerHub{}
grpcServer1.hub = hub1
grpcServer2.hub = hub2
grpcServer1.SetHub(hub1)
grpcServer2.SetHub(hub2)
embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))
@ -537,13 +580,13 @@ func Test_ProxyPublisherToken(t *testing.T) {
embedEtcd := etcdtest.NewServerForTest(t)
grpcServer1, addr1 := NewGrpcServerForTest(t)
grpcServer2, addr2 := NewGrpcServerForTest(t)
grpcServer1, addr1 := grpctest.NewServerForTest(t)
grpcServer2, addr2 := grpctest.NewServerForTest(t)
hub1 := &mockGrpcServerHub{}
hub2 := &mockGrpcServerHub{}
grpcServer1.hub = hub1
grpcServer2.hub = hub2
grpcServer1.SetHub(hub1)
grpcServer2.SetHub(hub2)
embedEtcd.SetValue("/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}"))
embedEtcd.SetValue("/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}"))

View file

@ -242,8 +242,8 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http
} else {
nats2 = nats1
}
grpcServer1, addr1 := NewGrpcServerForTest(t)
grpcServer2, addr2 := NewGrpcServerForTest(t)
grpcServer1, addr1 := grpctest.NewServerForTest(t)
grpcServer2, addr2 := grpctest.NewServerForTest(t)
if strings.Contains(t.Name(), "Federation") {
// Signaling servers should not form a cluster in federation tests.
@ -1905,7 +1905,7 @@ func TestClientHelloResumeProxy_Disconnect(t *testing.T) { // nolint:paralleltes
assert.Equal(hello.Hello.ResumeId, hello2.Hello.ResumeId, "%+v", hello2.Hello)
// Simulate unclean shutdown of second instance.
hub2.rpcServer.conn.Stop()
hub2.rpcServer.CloseUnclean()
assert.NoError(client2.WaitForClientRemoved(ctx))
})