mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-06-08 00:42:25 +02:00
Merge pull request #296 from strukturag/clustered-session-limit
Implement per-backend session limit for clusters.
This commit is contained in:
commit
51fb410c28
1
Makefile
1
Makefile
|
@ -119,6 +119,7 @@ common_easyjson: \
|
|||
api_signaling_easyjson.go
|
||||
|
||||
common_proto: \
|
||||
grpc_backend.pb.go \
|
||||
grpc_internal.pb.go \
|
||||
grpc_mcu.pb.go \
|
||||
grpc_sessions.pb.go
|
||||
|
|
|
@ -44,6 +44,7 @@ var (
|
|||
type Backend struct {
|
||||
id string
|
||||
url string
|
||||
parsedUrl *url.URL
|
||||
secret []byte
|
||||
compat bool
|
||||
|
||||
|
@ -80,6 +81,24 @@ func (b *Backend) IsUrlAllowed(u *url.URL) bool {
|
|||
}
|
||||
}
|
||||
|
||||
func (b *Backend) Url() string {
|
||||
return b.url
|
||||
}
|
||||
|
||||
func (b *Backend) ParsedUrl() *url.URL {
|
||||
return b.parsedUrl
|
||||
}
|
||||
|
||||
func (b *Backend) Limit() int {
|
||||
return int(b.sessionLimit)
|
||||
}
|
||||
|
||||
func (b *Backend) Len() int {
|
||||
b.sessionsLock.Lock()
|
||||
defer b.sessionsLock.Unlock()
|
||||
return len(b.sessions)
|
||||
}
|
||||
|
||||
func (b *Backend) AddSession(session Session) error {
|
||||
if session.ClientType() == HelloClientTypeInternal || session.ClientType() == HelloClientTypeVirtual {
|
||||
// Internal and virtual sessions are not counting to the limit.
|
||||
|
|
|
@ -165,6 +165,7 @@ func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data
|
|||
backend := &Backend{
|
||||
id: key,
|
||||
url: info.Url,
|
||||
parsedUrl: info.parsedUrl,
|
||||
secret: []byte(info.Secret),
|
||||
|
||||
allowHttp: info.parsedUrl.Scheme == "http",
|
||||
|
|
|
@ -241,6 +241,7 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile) (hosts map
|
|||
hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{
|
||||
id: id,
|
||||
url: u,
|
||||
parsedUrl: parsed,
|
||||
secret: []byte(secret),
|
||||
|
||||
allowHttp: parsed.Scheme == "http",
|
||||
|
|
38
grpc_backend.proto
Normal file
38
grpc_backend.proto
Normal file
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Standalone signaling server for the Nextcloud Spreed app.
|
||||
* Copyright (C) 2022 struktur AG
|
||||
*
|
||||
* @author Joachim Bauch <bauch@struktur.de>
|
||||
*
|
||||
* @license GNU AGPL version 3 or any later version
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
syntax = "proto3";
|
||||
|
||||
option go_package = "github.com/strukturag/nextcloud-spreed-signaling;signaling";
|
||||
|
||||
package signaling;
|
||||
|
||||
service RpcBackend {
|
||||
rpc GetSessionCount(GetSessionCountRequest) returns (GetSessionCountReply) {}
|
||||
}
|
||||
|
||||
message GetSessionCountRequest {
|
||||
string url = 1;
|
||||
}
|
||||
|
||||
message GetSessionCountReply {
|
||||
uint32 count = 1;
|
||||
}
|
|
@ -27,6 +27,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -58,6 +59,7 @@ func init() {
|
|||
}
|
||||
|
||||
type grpcClientImpl struct {
|
||||
RpcBackendClient
|
||||
RpcInternalClient
|
||||
RpcMcuClient
|
||||
RpcSessionsClient
|
||||
|
@ -65,6 +67,7 @@ type grpcClientImpl struct {
|
|||
|
||||
func newGrpcClientImpl(conn grpc.ClientConnInterface) *grpcClientImpl {
|
||||
return &grpcClientImpl{
|
||||
RpcBackendClient: NewRpcBackendClient(conn),
|
||||
RpcInternalClient: NewRpcInternalClient(conn),
|
||||
RpcMcuClient: NewRpcMcuClient(conn),
|
||||
RpcSessionsClient: NewRpcSessionsClient(conn),
|
||||
|
@ -243,6 +246,22 @@ func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, strea
|
|||
return response.GetPublisherId(), response.GetProxyUrl(), net.ParseIP(response.GetIp()), nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetSessionCount(ctx context.Context, u *url.URL) (uint32, error) {
|
||||
statsGrpcClientCalls.WithLabelValues("GetSessionCount").Inc()
|
||||
// TODO: Remove debug logging
|
||||
log.Printf("Get session count for %s on %s", u, c.Target())
|
||||
response, err := c.impl.GetSessionCount(ctx, &GetSessionCountRequest{
|
||||
Url: u.String(),
|
||||
}, grpc.WaitForReady(true))
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
return 0, nil
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return response.GetCount(), nil
|
||||
}
|
||||
|
||||
type GrpcClients struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/dlintw/goconf"
|
||||
|
@ -54,6 +55,7 @@ func init() {
|
|||
}
|
||||
|
||||
type GrpcServer struct {
|
||||
UnimplementedRpcBackendServer
|
||||
UnimplementedRpcInternalServer
|
||||
UnimplementedRpcMcuServer
|
||||
UnimplementedRpcSessionsServer
|
||||
|
@ -86,6 +88,7 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) {
|
|||
listener: listener,
|
||||
serverId: GrpcServerId,
|
||||
}
|
||||
RegisterRpcBackendServer(conn, result)
|
||||
RegisterRpcInternalServer(conn, result)
|
||||
RegisterRpcSessionsServer(conn, result)
|
||||
RegisterRpcMcuServer(conn, result)
|
||||
|
@ -189,3 +192,21 @@ func (s *GrpcServer) GetServerId(ctx context.Context, request *GetServerIdReques
|
|||
ServerId: s.serverId,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCountRequest) (*GetSessionCountReply, error) {
|
||||
statsGrpcServerCalls.WithLabelValues("SessionCount").Inc()
|
||||
|
||||
u, err := url.Parse(request.Url)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "invalid url")
|
||||
}
|
||||
|
||||
backend := s.hub.backend.GetBackend(u)
|
||||
if backend == nil {
|
||||
return nil, status.Error(codes.NotFound, "no such backend")
|
||||
}
|
||||
|
||||
return &GetSessionCountReply{
|
||||
Count: uint32(backend.Len()),
|
||||
}, nil
|
||||
}
|
||||
|
|
32
hub.go
32
hub.go
|
@ -754,6 +754,38 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
|
|||
return
|
||||
}
|
||||
|
||||
if limit := uint32(backend.Limit()); limit > 0 && h.rpcClients != nil {
|
||||
totalCount := uint32(backend.Len())
|
||||
var wg sync.WaitGroup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
for _, client := range h.rpcClients.GetClients() {
|
||||
wg.Add(1)
|
||||
go func(c *GrpcClient) {
|
||||
defer wg.Done()
|
||||
|
||||
count, err := c.GetSessionCount(ctx, backend.ParsedUrl())
|
||||
if err != nil {
|
||||
log.Printf("Received error while getting session count for %s from %s: %s", backend.Url(), c.Target(), err)
|
||||
return
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
log.Printf("%d sessions connected for %s on %s", count, backend.Url(), c.Target())
|
||||
atomic.AddUint32(&totalCount, count)
|
||||
}
|
||||
}(client)
|
||||
}
|
||||
wg.Wait()
|
||||
if totalCount > limit {
|
||||
backend.RemoveSession(session)
|
||||
log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), SessionLimitExceeded)
|
||||
session.Close()
|
||||
client.SendMessage(message.NewWrappedErrorServerMessage(SessionLimitExceeded))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
if !client.IsConnected() {
|
||||
// Client disconnected while waiting for backend response.
|
||||
|
|
76
hub_test.go
76
hub_test.go
|
@ -158,7 +158,7 @@ func CreateHubWithMultipleBackendsForTest(t *testing.T) (*Hub, AsyncEvents, *mux
|
|||
return h, events, r, server
|
||||
}
|
||||
|
||||
func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, *Hub, *httptest.Server, *httptest.Server) {
|
||||
func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, *Hub, *mux.Router, *mux.Router, *httptest.Server, *httptest.Server) {
|
||||
r1 := mux.NewRouter()
|
||||
registerBackendHandler(t, r1)
|
||||
|
||||
|
@ -237,11 +237,12 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http
|
|||
WaitForHub(ctx, t, h2)
|
||||
})
|
||||
|
||||
return h1, h2, server1, server2
|
||||
return h1, h2, r1, r2, server1, server2
|
||||
}
|
||||
|
||||
func CreateClusteredHubsForTest(t *testing.T) (*Hub, *Hub, *httptest.Server, *httptest.Server) {
|
||||
return CreateClusteredHubsForTestWithConfig(t, getTestConfig)
|
||||
h1, h2, _, _, server1, server2 := CreateClusteredHubsForTestWithConfig(t, getTestConfig)
|
||||
return h1, h2, server1, server2
|
||||
}
|
||||
|
||||
func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
|
||||
|
@ -750,7 +751,16 @@ func TestClientHelloAllowAll(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClientHelloSessionLimit(t *testing.T) {
|
||||
hub, _, router, server := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
|
||||
for _, subtest := range clusteredTests {
|
||||
t.Run(subtest, func(t *testing.T) {
|
||||
var hub1 *Hub
|
||||
var hub2 *Hub
|
||||
var server1 *httptest.Server
|
||||
var server2 *httptest.Server
|
||||
|
||||
if isLocalTest(t) {
|
||||
var router1 *mux.Router
|
||||
hub1, _, router1, server1 = CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
|
||||
config, err := getTestConfig(server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -769,16 +779,54 @@ func TestClientHelloSessionLimit(t *testing.T) {
|
|||
return config, nil
|
||||
})
|
||||
|
||||
registerBackendHandlerUrl(t, router, "/one")
|
||||
registerBackendHandlerUrl(t, router, "/two")
|
||||
registerBackendHandlerUrl(t, router1, "/one")
|
||||
registerBackendHandlerUrl(t, router1, "/two")
|
||||
|
||||
client := NewTestClient(t, server, hub)
|
||||
hub2 = hub1
|
||||
server2 = server1
|
||||
} else {
|
||||
var router1 *mux.Router
|
||||
var router2 *mux.Router
|
||||
hub1, hub2, router1, router2, server1, server2 = CreateClusteredHubsForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
|
||||
// Make sure all backends use the same server
|
||||
if server1 == nil {
|
||||
server1 = server
|
||||
} else {
|
||||
server = server1
|
||||
}
|
||||
|
||||
config, err := getTestConfig(server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.RemoveOption("backend", "allowed")
|
||||
config.RemoveOption("backend", "secret")
|
||||
config.AddOption("backend", "backends", "backend1, backend2")
|
||||
|
||||
config.AddOption("backend1", "url", server.URL+"/one")
|
||||
config.AddOption("backend1", "secret", string(testBackendSecret))
|
||||
config.AddOption("backend1", "sessionlimit", "1")
|
||||
|
||||
config.AddOption("backend2", "url", server.URL+"/two")
|
||||
config.AddOption("backend2", "secret", string(testBackendSecret))
|
||||
return config, nil
|
||||
})
|
||||
|
||||
registerBackendHandlerUrl(t, router1, "/one")
|
||||
registerBackendHandlerUrl(t, router1, "/two")
|
||||
|
||||
registerBackendHandlerUrl(t, router2, "/one")
|
||||
registerBackendHandlerUrl(t, router2, "/two")
|
||||
}
|
||||
|
||||
client := NewTestClient(t, server1, hub1)
|
||||
defer client.CloseWithBye()
|
||||
|
||||
params1 := TestBackendClientAuthParams{
|
||||
UserId: testDefaultUserId,
|
||||
}
|
||||
if err := client.SendHelloParams(server.URL+"/one", "client", params1); err != nil {
|
||||
if err := client.SendHelloParams(server1.URL+"/one", "client", params1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -797,13 +845,13 @@ func TestClientHelloSessionLimit(t *testing.T) {
|
|||
}
|
||||
|
||||
// The second client can't connect as it would exceed the session limit.
|
||||
client2 := NewTestClient(t, server, hub)
|
||||
client2 := NewTestClient(t, server2, hub2)
|
||||
defer client2.CloseWithBye()
|
||||
|
||||
params2 := TestBackendClientAuthParams{
|
||||
UserId: testDefaultUserId + "2",
|
||||
}
|
||||
if err := client2.SendHelloParams(server.URL+"/one", "client", params2); err != nil {
|
||||
if err := client2.SendHelloParams(server1.URL+"/one", "client", params2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -819,7 +867,7 @@ func TestClientHelloSessionLimit(t *testing.T) {
|
|||
}
|
||||
|
||||
// The client can connect to a different backend.
|
||||
if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil {
|
||||
if err := client2.SendHelloParams(server1.URL+"/two", "client", params2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -840,13 +888,13 @@ func TestClientHelloSessionLimit(t *testing.T) {
|
|||
t.Error(err)
|
||||
}
|
||||
|
||||
client3 := NewTestClient(t, server, hub)
|
||||
client3 := NewTestClient(t, server2, hub2)
|
||||
defer client3.CloseWithBye()
|
||||
|
||||
params3 := TestBackendClientAuthParams{
|
||||
UserId: testDefaultUserId + "3",
|
||||
}
|
||||
if err := client3.SendHelloParams(server.URL+"/one", "client", params3); err != nil {
|
||||
if err := client3.SendHelloParams(server1.URL+"/one", "client", params3); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -860,6 +908,8 @@ func TestClientHelloSessionLimit(t *testing.T) {
|
|||
t.Errorf("Expected session id, got %+v", hello.Hello)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionIdsUnordered(t *testing.T) {
|
||||
|
|
Loading…
Reference in a new issue