diff --git a/server.conf.in b/server.conf.in index e4b5dcb..ef9c9f2 100644 --- a/server.conf.in +++ b/server.conf.in @@ -82,6 +82,10 @@ connectionsperhost = 8 # same value as configured in the Nextcloud admin ui. #secret = the-shared-secret +# Limit the number of sessions that are allowed to connect to this backend. +# Omit or set to 0 to not limit the number of sessions. +#sessionlimit = 10 + #[another-backend] # URL of the Nextcloud instance #url = https://cloud.otherdomain.invalid diff --git a/src/signaling/api_signaling.go b/src/signaling/api_signaling.go index 8fdebe5..c43351b 100644 --- a/src/signaling/api_signaling.go +++ b/src/signaling/api_signaling.go @@ -104,6 +104,10 @@ func (m *ClientMessage) NewErrorServerMessage(e *Error) *ServerMessage { } func (m *ClientMessage) NewWrappedErrorServerMessage(e error) *ServerMessage { + if e, ok := e.(*Error); ok { + return m.NewErrorServerMessage(e) + } + return m.NewErrorServerMessage(NewError("internal_error", e.Error())) } diff --git a/src/signaling/backend_configuration.go b/src/signaling/backend_configuration.go index a1ecc35..8ff1536 100644 --- a/src/signaling/backend_configuration.go +++ b/src/signaling/backend_configuration.go @@ -26,15 +26,24 @@ import ( "net/url" "reflect" "strings" + "sync" "github.com/dlintw/goconf" ) +var ( + SessionLimitExceeded = NewError("session_limit_exceeded", "Too many sessions connected for this backend.") +) + type Backend struct { id string url string secret []byte compat bool + + sessionLimit uint64 + sessionsLock sync.Mutex + sessions map[string]bool } func (b *Backend) Id() string { @@ -49,6 +58,36 @@ func (b *Backend) IsCompat() bool { return b.compat } +func (b *Backend) AddSession(session Session) error { + if session.ClientType() == HelloClientTypeInternal || session.ClientType() == HelloClientTypeVirtual { + // Internal and virtual sessions are not counting to the limit. + return nil + } + + if b.sessionLimit == 0 { + // Not limited + return nil + } + + b.sessionsLock.Lock() + defer b.sessionsLock.Unlock() + if b.sessions == nil { + b.sessions = make(map[string]bool) + } else if uint64(len(b.sessions)) >= b.sessionLimit { + return SessionLimitExceeded + } + + b.sessions[session.PublicId()] = true + return nil +} + +func (b *Backend) RemoveSession(session Session) { + b.sessionsLock.Lock() + defer b.sessionsLock.Unlock() + + delete(b.sessions, session.PublicId()) +} + type BackendConfiguration struct { backends map[string][]*Backend @@ -61,6 +100,10 @@ type BackendConfiguration struct { func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, error) { allowAll, _ := config.GetBool("backend", "allowall") commonSecret, _ := config.GetString("backend", "secret") + sessionLimit, err := config.GetInt("backend", "sessionlimit") + if err != nil || sessionLimit < 0 { + sessionLimit = 0 + } backends := make(map[string][]*Backend) var compatBackend *Backend if allowAll { @@ -69,6 +112,11 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, id: "compat", secret: []byte(commonSecret), compat: true, + + sessionLimit: uint64(sessionLimit), + } + if sessionLimit > 0 { + log.Printf("Allow a maximum of %d sessions", sessionLimit) } } else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" { for host, configuredBackends := range getConfiguredHosts(backendIds, config) { @@ -98,6 +146,8 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, id: "compat", secret: []byte(commonSecret), compat: true, + + sessionLimit: uint64(sessionLimit), } hosts := make([]string, 0, len(allowMap)) for host := range allowMap { @@ -108,6 +158,9 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.") } log.Printf("Allowed backend hostnames: %s\n", hosts) + if sessionLimit > 0 { + log.Printf("Allow a maximum of %d sessions", sessionLimit) + } } } @@ -208,10 +261,20 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile) (hosts map continue } + sessionLimit, err := config.GetInt(id, "sessionlimit") + if err != nil || sessionLimit < 0 { + sessionLimit = 0 + } + if sessionLimit > 0 { + log.Printf("Backend %s allows a maximum of %d sessions", id, sessionLimit) + } + hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{ id: id, url: u, secret: []byte(secret), + + sessionLimit: uint64(sessionLimit), }) } diff --git a/src/signaling/backend_configuration_test.go b/src/signaling/backend_configuration_test.go index 1a468e1..4dce449 100644 --- a/src/signaling/backend_configuration_test.go +++ b/src/signaling/backend_configuration_test.go @@ -242,6 +242,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) { new_config.AddOption("backend", "allowall", "false") new_config.AddOption("backend1", "url", "http://domain3.invalid") new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") + new_config.AddOption("backend1", "sessionlimit", "10") new_config.AddOption("backend2", "url", "http://domain2.invalid") new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") n_cfg, err := NewBackendConfiguration(new_config) @@ -251,6 +252,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) { original_config.RemoveOption("backend1", "url") original_config.AddOption("backend1", "url", "http://domain3.invalid") + original_config.AddOption("backend1", "sessionlimit", "10") o_cfg.Reload(original_config) if !reflect.DeepEqual(n_cfg, o_cfg) { @@ -310,6 +312,7 @@ func TestBackendReloadAddBackend(t *testing.T) { new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1") new_config.AddOption("backend2", "url", "http://domain2.invalid") new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") + new_config.AddOption("backend2", "sessionlimit", "10") n_cfg, err := NewBackendConfiguration(new_config) if err != nil { t.Fatal(err) @@ -319,6 +322,7 @@ func TestBackendReloadAddBackend(t *testing.T) { original_config.AddOption("backend", "backends", "backend1, backend2") original_config.AddOption("backend2", "url", "http://domain2.invalid") original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2") + original_config.AddOption("backend2", "sessionlimit", "10") o_cfg.Reload(original_config) if !reflect.DeepEqual(n_cfg, o_cfg) { diff --git a/src/signaling/clientsession.go b/src/signaling/clientsession.go index 33bb671..d2bc979 100644 --- a/src/signaling/clientsession.go +++ b/src/signaling/clientsession.go @@ -329,6 +329,7 @@ func (s *ClientSession) closeAndWait(wait bool) { s.virtualSessions = nil s.releaseMcuObjects() s.clearClientLocked(nil) + s.backend.RemoveSession(s) if atomic.CompareAndSwapInt32(&s.running, 1, 0) { s.stopRun <- true // Only wait if called from outside the Session goroutine. diff --git a/src/signaling/hub.go b/src/signaling/hub.go index 5c7f795..112f463 100644 --- a/src/signaling/hub.go +++ b/src/signaling/hub.go @@ -705,6 +705,13 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B return } + if err := backend.AddSession(session); err != nil { + log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), err) + session.Close() + client.SendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + h.mu.Lock() if !client.IsConnected() { // Client disconnected while waiting for backend response. diff --git a/src/signaling/hub_test.go b/src/signaling/hub_test.go index 0b03d56..90d4a36 100644 --- a/src/signaling/hub_test.go +++ b/src/signaling/hub_test.go @@ -431,6 +431,120 @@ func TestClientHelloAllowAll(t *testing.T) { } } +func TestClientHelloSessionLimit(t *testing.T) { + hub, _, router, server, shutdown := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { + config, err := getTestConfig(server) + if err != nil { + return nil, err + } + + config.RemoveOption("backend", "allowed") + config.RemoveOption("backend", "secret") + config.AddOption("backend", "backends", "backend1, backend2") + + config.AddOption("backend1", "url", server.URL+"/one") + config.AddOption("backend1", "secret", string(testBackendSecret)) + config.AddOption("backend1", "sessionlimit", "1") + + config.AddOption("backend2", "url", server.URL+"/two") + config.AddOption("backend2", "secret", string(testBackendSecret)) + return config, nil + }) + defer shutdown() + + registerBackendHandlerUrl(t, router, "/one") + registerBackendHandlerUrl(t, router, "/two") + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + params1 := TestBackendClientAuthParams{ + UserId: testDefaultUserId, + } + if err := client.SendHelloParams(server.URL+"/one", "client", params1); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + if hello, err := client.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + } + + // The second client can't connect as it would exceed the session limit. + client2 := NewTestClient(t, server, hub) + defer client2.CloseWithBye() + + params2 := TestBackendClientAuthParams{ + UserId: testDefaultUserId + "2", + } + if err := client2.SendHelloParams(server.URL+"/one", "client", params2); err != nil { + t.Fatal(err) + } + + msg, err := client2.RunUntilMessage(ctx) + if err != nil { + t.Error(err) + } else { + if msg.Type != "error" || msg.Error == nil { + t.Errorf("Expected error message, got %+v", msg) + } else if msg.Error.Code != "session_limit_exceeded" { + t.Errorf("Expected error \"session_limit_exceeded\", got %+v", msg.Error.Code) + } + } + + // The client can connect to a different backend. + if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil { + t.Fatal(err) + } + + if hello, err := client2.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != testDefaultUserId+"2" { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + } + + // If the first client disconnects (and releases the session), a new one can connect. + client.CloseWithBye() + if err := client.WaitForClientRemoved(ctx); err != nil { + t.Error(err) + } + + client3 := NewTestClient(t, server, hub) + defer client3.CloseWithBye() + + params3 := TestBackendClientAuthParams{ + UserId: testDefaultUserId + "3", + } + if err := client3.SendHelloParams(server.URL+"/one", "client", params3); err != nil { + t.Fatal(err) + } + + if hello, err := client3.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != testDefaultUserId+"3" { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"3", hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + } +} + func TestSessionIdsUnordered(t *testing.T) { hub, _, _, server, shutdown := CreateHubForTest(t) defer shutdown()