From deaa17acc5490938e825dc2f993d69709ec2f5b8 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Wed, 13 Jul 2022 11:52:20 +0200 Subject: [PATCH] Implement per-backend session limit for clusters. --- Makefile | 1 + backend_configuration.go | 27 ++++- backend_storage_etcd.go | 7 +- backend_storage_static.go | 7 +- grpc_backend.proto | 38 ++++++ grpc_client.go | 19 +++ grpc_server.go | 21 ++++ hub.go | 32 +++++ hub_test.go | 240 +++++++++++++++++++++++--------------- 9 files changed, 287 insertions(+), 105 deletions(-) create mode 100644 grpc_backend.proto diff --git a/Makefile b/Makefile index 27c9d48..b9cc4f6 100644 --- a/Makefile +++ b/Makefile @@ -119,6 +119,7 @@ common_easyjson: \ api_signaling_easyjson.go common_proto: \ + grpc_backend.pb.go \ grpc_internal.pb.go \ grpc_mcu.pb.go \ grpc_sessions.pb.go diff --git a/backend_configuration.go b/backend_configuration.go index bc69ff7..fa6b45b 100644 --- a/backend_configuration.go +++ b/backend_configuration.go @@ -42,10 +42,11 @@ var ( ) type Backend struct { - id string - url string - secret []byte - compat bool + id string + url string + parsedUrl *url.URL + secret []byte + compat bool allowHttp bool @@ -80,6 +81,24 @@ func (b *Backend) IsUrlAllowed(u *url.URL) bool { } } +func (b *Backend) Url() string { + return b.url +} + +func (b *Backend) ParsedUrl() *url.URL { + return b.parsedUrl +} + +func (b *Backend) Limit() int { + return int(b.sessionLimit) +} + +func (b *Backend) Len() int { + b.sessionsLock.Lock() + defer b.sessionsLock.Unlock() + return len(b.sessions) +} + func (b *Backend) AddSession(session Session) error { if session.ClientType() == HelloClientTypeInternal || session.ClientType() == HelloClientTypeVirtual { // Internal and virtual sessions are not counting to the limit. diff --git a/backend_storage_etcd.go b/backend_storage_etcd.go index e71cda9..a33b216 100644 --- a/backend_storage_etcd.go +++ b/backend_storage_etcd.go @@ -163,9 +163,10 @@ func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data } backend := &Backend{ - id: key, - url: info.Url, - secret: []byte(info.Secret), + id: key, + url: info.Url, + parsedUrl: info.parsedUrl, + secret: []byte(info.Secret), allowHttp: info.parsedUrl.Scheme == "http", diff --git a/backend_storage_static.go b/backend_storage_static.go index e062e60..144f039 100644 --- a/backend_storage_static.go +++ b/backend_storage_static.go @@ -239,9 +239,10 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile) (hosts map } hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{ - id: id, - url: u, - secret: []byte(secret), + id: id, + url: u, + parsedUrl: parsed, + secret: []byte(secret), allowHttp: parsed.Scheme == "http", diff --git a/grpc_backend.proto b/grpc_backend.proto new file mode 100644 index 0000000..f667f12 --- /dev/null +++ b/grpc_backend.proto @@ -0,0 +1,38 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2022 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + syntax = "proto3"; + + option go_package = "github.com/strukturag/nextcloud-spreed-signaling;signaling"; + + package signaling; + + service RpcBackend { + rpc GetSessionCount(GetSessionCountRequest) returns (GetSessionCountReply) {} + } + + message GetSessionCountRequest { + string url = 1; + } + + message GetSessionCountReply { + uint32 count = 1; + } diff --git a/grpc_client.go b/grpc_client.go index f31b7db..18ef04a 100644 --- a/grpc_client.go +++ b/grpc_client.go @@ -27,6 +27,7 @@ import ( "fmt" "log" "net" + "net/url" "strings" "sync" "sync/atomic" @@ -58,6 +59,7 @@ func init() { } type grpcClientImpl struct { + RpcBackendClient RpcInternalClient RpcMcuClient RpcSessionsClient @@ -65,6 +67,7 @@ type grpcClientImpl struct { func newGrpcClientImpl(conn grpc.ClientConnInterface) *grpcClientImpl { return &grpcClientImpl{ + RpcBackendClient: NewRpcBackendClient(conn), RpcInternalClient: NewRpcInternalClient(conn), RpcMcuClient: NewRpcMcuClient(conn), RpcSessionsClient: NewRpcSessionsClient(conn), @@ -243,6 +246,22 @@ func (c *GrpcClient) GetPublisherId(ctx context.Context, sessionId string, strea return response.GetPublisherId(), response.GetProxyUrl(), net.ParseIP(response.GetIp()), nil } +func (c *GrpcClient) GetSessionCount(ctx context.Context, u *url.URL) (uint32, error) { + statsGrpcClientCalls.WithLabelValues("GetSessionCount").Inc() + // TODO: Remove debug logging + log.Printf("Get session count for %s on %s", u, c.Target()) + response, err := c.impl.GetSessionCount(ctx, &GetSessionCountRequest{ + Url: u.String(), + }, grpc.WaitForReady(true)) + if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + return 0, nil + } else if err != nil { + return 0, err + } + + return response.GetCount(), nil +} + type GrpcClients struct { mu sync.RWMutex diff --git a/grpc_server.go b/grpc_server.go index 2631b19..97b6368 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -29,6 +29,7 @@ import ( "fmt" "log" "net" + "net/url" "os" "github.com/dlintw/goconf" @@ -54,6 +55,7 @@ func init() { } type GrpcServer struct { + UnimplementedRpcBackendServer UnimplementedRpcInternalServer UnimplementedRpcMcuServer UnimplementedRpcSessionsServer @@ -86,6 +88,7 @@ func NewGrpcServer(config *goconf.ConfigFile) (*GrpcServer, error) { listener: listener, serverId: GrpcServerId, } + RegisterRpcBackendServer(conn, result) RegisterRpcInternalServer(conn, result) RegisterRpcSessionsServer(conn, result) RegisterRpcMcuServer(conn, result) @@ -189,3 +192,21 @@ func (s *GrpcServer) GetServerId(ctx context.Context, request *GetServerIdReques ServerId: s.serverId, }, nil } + +func (s *GrpcServer) GetSessionCount(ctx context.Context, request *GetSessionCountRequest) (*GetSessionCountReply, error) { + statsGrpcServerCalls.WithLabelValues("SessionCount").Inc() + + u, err := url.Parse(request.Url) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "invalid url") + } + + backend := s.hub.backend.GetBackend(u) + if backend == nil { + return nil, status.Error(codes.NotFound, "no such backend") + } + + return &GetSessionCountReply{ + Count: uint32(backend.Len()), + }, nil +} diff --git a/hub.go b/hub.go index 2f34d82..c5194a5 100644 --- a/hub.go +++ b/hub.go @@ -754,6 +754,38 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B return } + if limit := uint32(backend.Limit()); limit > 0 && h.rpcClients != nil { + totalCount := uint32(backend.Len()) + var wg sync.WaitGroup + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + for _, client := range h.rpcClients.GetClients() { + wg.Add(1) + go func(c *GrpcClient) { + defer wg.Done() + + count, err := c.GetSessionCount(ctx, backend.ParsedUrl()) + if err != nil { + log.Printf("Received error while getting session count for %s from %s: %s", backend.Url(), c.Target(), err) + return + } + + if count > 0 { + log.Printf("%d sessions connected for %s on %s", count, backend.Url(), c.Target()) + atomic.AddUint32(&totalCount, count) + } + }(client) + } + wg.Wait() + if totalCount > limit { + backend.RemoveSession(session) + log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), SessionLimitExceeded) + session.Close() + client.SendMessage(message.NewWrappedErrorServerMessage(SessionLimitExceeded)) + return + } + } + h.mu.Lock() if !client.IsConnected() { // Client disconnected while waiting for backend response. diff --git a/hub_test.go b/hub_test.go index bb0fbe8..535328a 100644 --- a/hub_test.go +++ b/hub_test.go @@ -158,7 +158,7 @@ func CreateHubWithMultipleBackendsForTest(t *testing.T) (*Hub, AsyncEvents, *mux return h, events, r, server } -func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, *Hub, *httptest.Server, *httptest.Server) { +func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*httptest.Server) (*goconf.ConfigFile, error)) (*Hub, *Hub, *mux.Router, *mux.Router, *httptest.Server, *httptest.Server) { r1 := mux.NewRouter() registerBackendHandler(t, r1) @@ -237,11 +237,12 @@ func CreateClusteredHubsForTestWithConfig(t *testing.T, getConfigFunc func(*http WaitForHub(ctx, t, h2) }) - return h1, h2, server1, server2 + return h1, h2, r1, r2, server1, server2 } func CreateClusteredHubsForTest(t *testing.T) (*Hub, *Hub, *httptest.Server, *httptest.Server) { - return CreateClusteredHubsForTestWithConfig(t, getTestConfig) + h1, h2, _, _, server1, server2 := CreateClusteredHubsForTestWithConfig(t, getTestConfig) + return h1, h2, server1, server2 } func WaitForHub(ctx context.Context, t *testing.T, h *Hub) { @@ -750,115 +751,164 @@ func TestClientHelloAllowAll(t *testing.T) { } func TestClientHelloSessionLimit(t *testing.T) { - hub, _, router, server := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { - config, err := getTestConfig(server) - if err != nil { - return nil, err - } + for _, subtest := range clusteredTests { + t.Run(subtest, func(t *testing.T) { + var hub1 *Hub + var hub2 *Hub + var server1 *httptest.Server + var server2 *httptest.Server - config.RemoveOption("backend", "allowed") - config.RemoveOption("backend", "secret") - config.AddOption("backend", "backends", "backend1, backend2") + if isLocalTest(t) { + var router1 *mux.Router + hub1, _, router1, server1 = CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { + config, err := getTestConfig(server) + if err != nil { + return nil, err + } - config.AddOption("backend1", "url", server.URL+"/one") - config.AddOption("backend1", "secret", string(testBackendSecret)) - config.AddOption("backend1", "sessionlimit", "1") + config.RemoveOption("backend", "allowed") + config.RemoveOption("backend", "secret") + config.AddOption("backend", "backends", "backend1, backend2") - config.AddOption("backend2", "url", server.URL+"/two") - config.AddOption("backend2", "secret", string(testBackendSecret)) - return config, nil - }) + config.AddOption("backend1", "url", server.URL+"/one") + config.AddOption("backend1", "secret", string(testBackendSecret)) + config.AddOption("backend1", "sessionlimit", "1") - registerBackendHandlerUrl(t, router, "/one") - registerBackendHandlerUrl(t, router, "/two") + config.AddOption("backend2", "url", server.URL+"/two") + config.AddOption("backend2", "secret", string(testBackendSecret)) + return config, nil + }) - client := NewTestClient(t, server, hub) - defer client.CloseWithBye() + registerBackendHandlerUrl(t, router1, "/one") + registerBackendHandlerUrl(t, router1, "/two") - params1 := TestBackendClientAuthParams{ - UserId: testDefaultUserId, - } - if err := client.SendHelloParams(server.URL+"/one", "client", params1); err != nil { - t.Fatal(err) - } + hub2 = hub1 + server2 = server1 + } else { + var router1 *mux.Router + var router2 *mux.Router + hub1, hub2, router1, router2, server1, server2 = CreateClusteredHubsForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { + // Make sure all backends use the same server + if server1 == nil { + server1 = server + } else { + server = server1 + } - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() + config, err := getTestConfig(server) + if err != nil { + return nil, err + } - if hello, err := client.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - } + config.RemoveOption("backend", "allowed") + config.RemoveOption("backend", "secret") + config.AddOption("backend", "backends", "backend1, backend2") - // The second client can't connect as it would exceed the session limit. - client2 := NewTestClient(t, server, hub) - defer client2.CloseWithBye() + config.AddOption("backend1", "url", server.URL+"/one") + config.AddOption("backend1", "secret", string(testBackendSecret)) + config.AddOption("backend1", "sessionlimit", "1") - params2 := TestBackendClientAuthParams{ - UserId: testDefaultUserId + "2", - } - if err := client2.SendHelloParams(server.URL+"/one", "client", params2); err != nil { - t.Fatal(err) - } + config.AddOption("backend2", "url", server.URL+"/two") + config.AddOption("backend2", "secret", string(testBackendSecret)) + return config, nil + }) - msg, err := client2.RunUntilMessage(ctx) - if err != nil { - t.Error(err) - } else { - if msg.Type != "error" || msg.Error == nil { - t.Errorf("Expected error message, got %+v", msg) - } else if msg.Error.Code != "session_limit_exceeded" { - t.Errorf("Expected error \"session_limit_exceeded\", got %+v", msg.Error.Code) - } - } + registerBackendHandlerUrl(t, router1, "/one") + registerBackendHandlerUrl(t, router1, "/two") - // The client can connect to a different backend. - if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil { - t.Fatal(err) - } + registerBackendHandlerUrl(t, router2, "/one") + registerBackendHandlerUrl(t, router2, "/two") + } - if hello, err := client2.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId+"2" { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } - } + client := NewTestClient(t, server1, hub1) + defer client.CloseWithBye() - // If the first client disconnects (and releases the session), a new one can connect. - client.CloseWithBye() - if err := client.WaitForClientRemoved(ctx); err != nil { - t.Error(err) - } + params1 := TestBackendClientAuthParams{ + UserId: testDefaultUserId, + } + if err := client.SendHelloParams(server1.URL+"/one", "client", params1); err != nil { + t.Fatal(err) + } - client3 := NewTestClient(t, server, hub) - defer client3.CloseWithBye() + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() - params3 := TestBackendClientAuthParams{ - UserId: testDefaultUserId + "3", - } - if err := client3.SendHelloParams(server.URL+"/one", "client", params3); err != nil { - t.Fatal(err) - } + if hello, err := client.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + } - if hello, err := client3.RunUntilHello(ctx); err != nil { - t.Error(err) - } else { - if hello.Hello.UserId != testDefaultUserId+"3" { - t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"3", hello.Hello) - } - if hello.Hello.SessionId == "" { - t.Errorf("Expected session id, got %+v", hello.Hello) - } + // The second client can't connect as it would exceed the session limit. + client2 := NewTestClient(t, server2, hub2) + defer client2.CloseWithBye() + + params2 := TestBackendClientAuthParams{ + UserId: testDefaultUserId + "2", + } + if err := client2.SendHelloParams(server1.URL+"/one", "client", params2); err != nil { + t.Fatal(err) + } + + msg, err := client2.RunUntilMessage(ctx) + if err != nil { + t.Error(err) + } else { + if msg.Type != "error" || msg.Error == nil { + t.Errorf("Expected error message, got %+v", msg) + } else if msg.Error.Code != "session_limit_exceeded" { + t.Errorf("Expected error \"session_limit_exceeded\", got %+v", msg.Error.Code) + } + } + + // The client can connect to a different backend. + if err := client2.SendHelloParams(server1.URL+"/two", "client", params2); err != nil { + t.Fatal(err) + } + + if hello, err := client2.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != testDefaultUserId+"2" { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + } + + // If the first client disconnects (and releases the session), a new one can connect. + client.CloseWithBye() + if err := client.WaitForClientRemoved(ctx); err != nil { + t.Error(err) + } + + client3 := NewTestClient(t, server2, hub2) + defer client3.CloseWithBye() + + params3 := TestBackendClientAuthParams{ + UserId: testDefaultUserId + "3", + } + if err := client3.SendHelloParams(server1.URL+"/one", "client", params3); err != nil { + t.Fatal(err) + } + + if hello, err := client3.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != testDefaultUserId+"3" { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"3", hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + } + }) } }