Implement per-backend session limit for clusters.

This commit is contained in:
Joachim Bauch 2022-07-13 11:52:20 +02:00
parent 12a8fa98d0
commit deaa17acc5
No known key found for this signature in database
GPG Key ID: 77C1D22D53E15F02
9 changed files with 287 additions and 105 deletions

View File

@ -119,6 +119,7 @@ common_easyjson: \
api_signaling_easyjson.go
common_proto: \
grpc_backend.pb.go \
grpc_internal.pb.go \
grpc_mcu.pb.go \
grpc_sessions.pb.go

View File

@ -42,10 +42,11 @@ var (
)
type Backend struct {
id string
url string
secret []byte
compat bool
id string
url string
parsedUrl *url.URL
secret []byte
compat bool
allowHttp bool
@ -80,6 +81,24 @@ func (b *Backend) IsUrlAllowed(u *url.URL) bool {
}
}
// Url returns the configured URL of this backend as a string.
func (b *Backend) Url() string {
return b.url
}
// ParsedUrl returns the pre-parsed URL of this backend, avoiding the need
// for callers to re-parse the string returned by Url().
func (b *Backend) ParsedUrl() *url.URL {
return b.parsedUrl
}
// Limit returns the maximum number of sessions allowed on this backend.
// A value of 0 means no limit is enforced (callers check for "limit > 0").
func (b *Backend) Limit() int {
return int(b.sessionLimit)
}
// Len returns the number of sessions currently registered on this backend.
func (b *Backend) Len() int {
	b.sessionsLock.Lock()
	count := len(b.sessions)
	b.sessionsLock.Unlock()
	return count
}
func (b *Backend) AddSession(session Session) error {
if session.ClientType() == HelloClientTypeInternal || session.ClientType() == HelloClientTypeVirtual {
// Internal and virtual sessions do not count towards the limit.

View File

@ -163,9 +163,10 @@ func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data
}
backend := &Backend{
id: key,
url: info.Url,
secret: []byte(info.Secret),
id: key,
url: info.Url,
parsedUrl: info.parsedUrl,
secret: []byte(info.Secret),
allowHttp: info.parsedUrl.Scheme == "http",

View File

@ -239,9 +239,10 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile) (hosts map
}
hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{
id: id,
url: u,
secret: []byte(secret),
id: id,
url: u,
parsedUrl: parsed,
secret: []byte(secret),
allowHttp: parsed.Scheme == "http",

38
grpc_backend.proto Normal file
View File

@ -0,0 +1,38 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
syntax = "proto3";
option go_package = "github.com/strukturag/nextcloud-spreed-signaling;signaling";
package signaling;
// RpcBackend provides information about configured backends to other
// signaling servers in a cluster.
service RpcBackend {
// GetSessionCount returns the number of sessions that are connected to the
// backend with the given URL on the queried server.
rpc GetSessionCount(GetSessionCountRequest) returns (GetSessionCountReply) {}
}
message GetSessionCountRequest {
// The URL of the backend to query the session count for.
string url = 1;
}
message GetSessionCountReply {
// Number of sessions connected for the requested backend.
uint32 count = 1;
}

View File

@ -27,6 +27,7 @@ import (
"fmt"
"log"
"net"
"net/url"
"strings"
"sync"
"sync/atomic"
@ -58,6 +59,7 @@ func init() {
}
type grpcClientImpl struct {
RpcBackendClient
RpcInternalClient
RpcMcuClient
RpcSessionsClient
@ -65,6 +67,7 @@ type grpcClientImpl struct {
func newGrpcClientImpl(conn grpc.ClientConnInterface) *grpcClientImpl {
return &grpcClientImpl{
RpcBackendClient: NewRpcBackendClient(conn),
RpcInternalClient: NewRpcInternalClient(conn),
RpcMcuClient: NewRpcMcuClient(conn),
RpcSessionsClient: NewRpcSessionsClient(conn),
@ -243,6 +246,22 @@ func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, strea
return response.GetPublisherId(), response.GetProxyUrl(), net.ParseIP(response.GetIp()), nil
}
// GetSessionCount queries the remote signaling server for the number of
// sessions connected to the backend with the given URL.
// A "NotFound" status (the remote server doesn't know this backend) is not
// treated as an error but mapped to a count of 0.
func (c *GrpcClient) GetSessionCount(ctx context.Context, u *url.URL) (uint32, error) {
	statsGrpcClientCalls.WithLabelValues("GetSessionCount").Inc()
	response, err := c.impl.GetSessionCount(ctx, &GetSessionCountRequest{
		Url: u.String(),
	}, grpc.WaitForReady(true))
	if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
		return 0, nil
	} else if err != nil {
		return 0, err
	}
	return response.GetCount(), nil
}
type GrpcClients struct {
mu sync.RWMutex

View File

@ -29,6 +29,7 @@ import (
"fmt"
"log"
"net"
"net/url"
"os"
"github.com/dlintw/goconf"
@ -54,6 +55,7 @@ func init() {
}
type GrpcServer struct {
UnimplementedRpcBackendServer
UnimplementedRpcInternalServer
UnimplementedRpcMcuServer
UnimplementedRpcSessionsServer
@ -86,6 +88,7 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) {
listener: listener,
serverId: GrpcServerId,
}
RegisterRpcBackendServer(conn, result)
RegisterRpcInternalServer(conn, result)
RegisterRpcSessionsServer(conn, result)
RegisterRpcMcuServer(conn, result)
@ -189,3 +192,21 @@ func (s *GrpcServer) GetServerId(ctx context.Context, request *GetServerIdReques
ServerId: s.serverId,
}, nil
}
// GetSessionCount returns the number of sessions that are connected to the
// backend with the given URL on this server.
// Responds with "InvalidArgument" if the URL can't be parsed and "NotFound"
// if no backend is configured for it.
func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCountRequest) (*GetSessionCountReply, error) {
	// Use the method name as label, consistent with the other RPC handlers
	// and the client side (which also uses "GetSessionCount").
	statsGrpcServerCalls.WithLabelValues("GetSessionCount").Inc()
	u, err := url.Parse(request.Url)
	if err != nil {
		return nil, status.Error(codes.InvalidArgument, "invalid url")
	}
	backend := s.hub.backend.GetBackend(u)
	if backend == nil {
		return nil, status.Error(codes.NotFound, "no such backend")
	}
	return &GetSessionCountReply{
		Count: uint32(backend.Len()),
	}, nil
}

32
hub.go
View File

@ -754,6 +754,38 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
return
}
if limit := uint32(backend.Limit()); limit > 0 && h.rpcClients != nil {
totalCount := uint32(backend.Len())
var wg sync.WaitGroup
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
for _, client := range h.rpcClients.GetClients() {
wg.Add(1)
go func(c *GrpcClient) {
defer wg.Done()
count, err := c.GetSessionCount(ctx, backend.ParsedUrl())
if err != nil {
log.Printf("Received error while getting session count for %s from %s: %s", backend.Url(), c.Target(), err)
return
}
if count > 0 {
log.Printf("%d sessions connected for %s on %s", count, backend.Url(), c.Target())
atomic.AddUint32(&totalCount, count)
}
}(client)
}
wg.Wait()
if totalCount > limit {
backend.RemoveSession(session)
log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), SessionLimitExceeded)
session.Close()
client.SendMessage(message.NewWrappedErrorServerMessage(SessionLimitExceeded))
return
}
}
h.mu.Lock()
if !client.IsConnected() {
// Client disconnected while waiting for backend response.

View File

@ -158,7 +158,7 @@ func CreateHubWithMultipleBackendsForTest(t *testing.T) (*Hub, AsyncEvents, *mux
return h, events, r, server
}
func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, *Hub, *httptest.Server, *httptest.Server) {
func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, *Hub, *mux.Router, *mux.Router, *httptest.Server, *httptest.Server) {
r1 := mux.NewRouter()
registerBackendHandler(t, r1)
@ -237,11 +237,12 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http
WaitForHub(ctx, t, h2)
})
return h1, h2, server1, server2
return h1, h2, r1, r2, server1, server2
}
func CreateClusteredHubsForTest(t *testing.T) (*Hub, *Hub, *httptest.Server, *httptest.Server) {
return CreateClusteredHubsForTestWithConfig(t, getTestConfig)
h1, h2, _, _, server1, server2 := CreateClusteredHubsForTestWithConfig(t, getTestConfig)
return h1, h2, server1, server2
}
func WaitForHub(ctx context.Context, t *testing.T, h *Hub) {
@ -750,115 +751,164 @@ func TestClientHelloAllowAll(t *testing.T) {
}
func TestClientHelloSessionLimit(t *testing.T) {
hub, _, router, server := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
config, err := getTestConfig(server)
if err != nil {
return nil, err
}
for _, subtest := range clusteredTests {
t.Run(subtest, func(t *testing.T) {
var hub1 *Hub
var hub2 *Hub
var server1 *httptest.Server
var server2 *httptest.Server
config.RemoveOption("backend", "allowed")
config.RemoveOption("backend", "secret")
config.AddOption("backend", "backends", "backend1, backend2")
if isLocalTest(t) {
var router1 *mux.Router
hub1, _, router1, server1 = CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
config, err := getTestConfig(server)
if err != nil {
return nil, err
}
config.AddOption("backend1", "url", server.URL+"/one")
config.AddOption("backend1", "secret", string(testBackendSecret))
config.AddOption("backend1", "sessionlimit", "1")
config.RemoveOption("backend", "allowed")
config.RemoveOption("backend", "secret")
config.AddOption("backend", "backends", "backend1, backend2")
config.AddOption("backend2", "url", server.URL+"/two")
config.AddOption("backend2", "secret", string(testBackendSecret))
return config, nil
})
config.AddOption("backend1", "url", server.URL+"/one")
config.AddOption("backend1", "secret", string(testBackendSecret))
config.AddOption("backend1", "sessionlimit", "1")
registerBackendHandlerUrl(t, router, "/one")
registerBackendHandlerUrl(t, router, "/two")
config.AddOption("backend2", "url", server.URL+"/two")
config.AddOption("backend2", "secret", string(testBackendSecret))
return config, nil
})
client := NewTestClient(t, server, hub)
defer client.CloseWithBye()
registerBackendHandlerUrl(t, router1, "/one")
registerBackendHandlerUrl(t, router1, "/two")
params1 := TestBackendClientAuthParams{
UserId: testDefaultUserId,
}
if err := client.SendHelloParams(server.URL+"/one", "client", params1); err != nil {
t.Fatal(err)
}
hub2 = hub1
server2 = server1
} else {
var router1 *mux.Router
var router2 *mux.Router
hub1, hub2, router1, router2, server1, server2 = CreateClusteredHubsForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
// Make sure all backends use the same server
if server1 == nil {
server1 = server
} else {
server = server1
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
config, err := getTestConfig(server)
if err != nil {
return nil, err
}
if hello, err := client.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
}
config.RemoveOption("backend", "allowed")
config.RemoveOption("backend", "secret")
config.AddOption("backend", "backends", "backend1, backend2")
// The second client can't connect as it would exceed the session limit.
client2 := NewTestClient(t, server, hub)
defer client2.CloseWithBye()
config.AddOption("backend1", "url", server.URL+"/one")
config.AddOption("backend1", "secret", string(testBackendSecret))
config.AddOption("backend1", "sessionlimit", "1")
params2 := TestBackendClientAuthParams{
UserId: testDefaultUserId + "2",
}
if err := client2.SendHelloParams(server.URL+"/one", "client", params2); err != nil {
t.Fatal(err)
}
config.AddOption("backend2", "url", server.URL+"/two")
config.AddOption("backend2", "secret", string(testBackendSecret))
return config, nil
})
msg, err := client2.RunUntilMessage(ctx)
if err != nil {
t.Error(err)
} else {
if msg.Type != "error" || msg.Error == nil {
t.Errorf("Expected error message, got %+v", msg)
} else if msg.Error.Code != "session_limit_exceeded" {
t.Errorf("Expected error \"session_limit_exceeded\", got %+v", msg.Error.Code)
}
}
registerBackendHandlerUrl(t, router1, "/one")
registerBackendHandlerUrl(t, router1, "/two")
// The client can connect to a different backend.
if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil {
t.Fatal(err)
}
registerBackendHandlerUrl(t, router2, "/one")
registerBackendHandlerUrl(t, router2, "/two")
}
if hello, err := client2.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != testDefaultUserId+"2" {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
}
client := NewTestClient(t, server1, hub1)
defer client.CloseWithBye()
// If the first client disconnects (and releases the session), a new one can connect.
client.CloseWithBye()
if err := client.WaitForClientRemoved(ctx); err != nil {
t.Error(err)
}
params1 := TestBackendClientAuthParams{
UserId: testDefaultUserId,
}
if err := client.SendHelloParams(server1.URL+"/one", "client", params1); err != nil {
t.Fatal(err)
}
client3 := NewTestClient(t, server, hub)
defer client3.CloseWithBye()
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
params3 := TestBackendClientAuthParams{
UserId: testDefaultUserId + "3",
}
if err := client3.SendHelloParams(server.URL+"/one", "client", params3); err != nil {
t.Fatal(err)
}
if hello, err := client.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
}
if hello, err := client3.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != testDefaultUserId+"3" {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"3", hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
// The second client can't connect as it would exceed the session limit.
client2 := NewTestClient(t, server2, hub2)
defer client2.CloseWithBye()
params2 := TestBackendClientAuthParams{
UserId: testDefaultUserId + "2",
}
if err := client2.SendHelloParams(server1.URL+"/one", "client", params2); err != nil {
t.Fatal(err)
}
msg, err := client2.RunUntilMessage(ctx)
if err != nil {
t.Error(err)
} else {
if msg.Type != "error" || msg.Error == nil {
t.Errorf("Expected error message, got %+v", msg)
} else if msg.Error.Code != "session_limit_exceeded" {
t.Errorf("Expected error \"session_limit_exceeded\", got %+v", msg.Error.Code)
}
}
// The client can connect to a different backend.
if err := client2.SendHelloParams(server1.URL+"/two", "client", params2); err != nil {
t.Fatal(err)
}
if hello, err := client2.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != testDefaultUserId+"2" {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
}
// If the first client disconnects (and releases the session), a new one can connect.
client.CloseWithBye()
if err := client.WaitForClientRemoved(ctx); err != nil {
t.Error(err)
}
client3 := NewTestClient(t, server2, hub2)
defer client3.CloseWithBye()
params3 := TestBackendClientAuthParams{
UserId: testDefaultUserId + "3",
}
if err := client3.SendHelloParams(server1.URL+"/one", "client", params3); err != nil {
t.Fatal(err)
}
if hello, err := client3.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != testDefaultUserId+"3" {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"3", hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
}
})
}
}