mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-06-10 09:52:12 +02:00
Support limiting number of sessions per backend.
This commit is contained in:
parent
b62beb2d31
commit
a3e34143c5
|
@ -82,6 +82,10 @@ connectionsperhost = 8
|
||||||
# same value as configured in the Nextcloud admin ui.
|
# same value as configured in the Nextcloud admin ui.
|
||||||
#secret = the-shared-secret
|
#secret = the-shared-secret
|
||||||
|
|
||||||
|
# Limit the number of sessions that are allowed to connect to this backend.
|
||||||
|
# Omit or set to 0 to not limit the number of sessions.
|
||||||
|
#sessionlimit = 10
|
||||||
|
|
||||||
#[another-backend]
|
#[another-backend]
|
||||||
# URL of the Nextcloud instance
|
# URL of the Nextcloud instance
|
||||||
#url = https://cloud.otherdomain.invalid
|
#url = https://cloud.otherdomain.invalid
|
||||||
|
|
|
@ -26,15 +26,24 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/dlintw/goconf"
|
"github.com/dlintw/goconf"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
SessionLimitExceeded = NewError("session_limit_exceeded", "Too many sessions connected for this backend.")
|
||||||
|
)
|
||||||
|
|
||||||
type Backend struct {
|
type Backend struct {
|
||||||
id string
|
id string
|
||||||
url string
|
url string
|
||||||
secret []byte
|
secret []byte
|
||||||
compat bool
|
compat bool
|
||||||
|
|
||||||
|
sessionLimit uint64
|
||||||
|
sessionsLock sync.Mutex
|
||||||
|
sessions map[string]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Backend) Id() string {
|
func (b *Backend) Id() string {
|
||||||
|
@ -49,6 +58,36 @@ func (b *Backend) IsCompat() bool {
|
||||||
return b.compat
|
return b.compat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Backend) AddSession(session Session) error {
|
||||||
|
if session.ClientType() == HelloClientTypeInternal || session.ClientType() == HelloClientTypeVirtual {
|
||||||
|
// Internal and virtual sessions are not counting to the limit.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.sessionLimit == 0 {
|
||||||
|
// Not limited
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b.sessionsLock.Lock()
|
||||||
|
defer b.sessionsLock.Unlock()
|
||||||
|
if b.sessions == nil {
|
||||||
|
b.sessions = make(map[string]bool)
|
||||||
|
} else if uint64(len(b.sessions)) >= b.sessionLimit {
|
||||||
|
return SessionLimitExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
b.sessions[session.PublicId()] = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Backend) RemoveSession(session Session) {
|
||||||
|
b.sessionsLock.Lock()
|
||||||
|
defer b.sessionsLock.Unlock()
|
||||||
|
|
||||||
|
delete(b.sessions, session.PublicId())
|
||||||
|
}
|
||||||
|
|
||||||
type BackendConfiguration struct {
|
type BackendConfiguration struct {
|
||||||
backends map[string][]*Backend
|
backends map[string][]*Backend
|
||||||
|
|
||||||
|
@ -61,6 +100,10 @@ type BackendConfiguration struct {
|
||||||
func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, error) {
|
func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, error) {
|
||||||
allowAll, _ := config.GetBool("backend", "allowall")
|
allowAll, _ := config.GetBool("backend", "allowall")
|
||||||
commonSecret, _ := config.GetString("backend", "secret")
|
commonSecret, _ := config.GetString("backend", "secret")
|
||||||
|
sessionLimit, err := config.GetInt("backend", "sessionlimit")
|
||||||
|
if err != nil || sessionLimit < 0 {
|
||||||
|
sessionLimit = 0
|
||||||
|
}
|
||||||
backends := make(map[string][]*Backend)
|
backends := make(map[string][]*Backend)
|
||||||
var compatBackend *Backend
|
var compatBackend *Backend
|
||||||
if allowAll {
|
if allowAll {
|
||||||
|
@ -69,6 +112,11 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration,
|
||||||
id: "compat",
|
id: "compat",
|
||||||
secret: []byte(commonSecret),
|
secret: []byte(commonSecret),
|
||||||
compat: true,
|
compat: true,
|
||||||
|
|
||||||
|
sessionLimit: uint64(sessionLimit),
|
||||||
|
}
|
||||||
|
if sessionLimit > 0 {
|
||||||
|
log.Printf("Allow a maximum of %d sessions", sessionLimit)
|
||||||
}
|
}
|
||||||
} else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
|
} else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
|
||||||
for host, configuredBackends := range getConfiguredHosts(backendIds, config) {
|
for host, configuredBackends := range getConfiguredHosts(backendIds, config) {
|
||||||
|
@ -98,6 +146,8 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration,
|
||||||
id: "compat",
|
id: "compat",
|
||||||
secret: []byte(commonSecret),
|
secret: []byte(commonSecret),
|
||||||
compat: true,
|
compat: true,
|
||||||
|
|
||||||
|
sessionLimit: uint64(sessionLimit),
|
||||||
}
|
}
|
||||||
hosts := make([]string, 0, len(allowMap))
|
hosts := make([]string, 0, len(allowMap))
|
||||||
for host := range allowMap {
|
for host := range allowMap {
|
||||||
|
@ -108,6 +158,9 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration,
|
||||||
log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.")
|
log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.")
|
||||||
}
|
}
|
||||||
log.Printf("Allowed backend hostnames: %s\n", hosts)
|
log.Printf("Allowed backend hostnames: %s\n", hosts)
|
||||||
|
if sessionLimit > 0 {
|
||||||
|
log.Printf("Allow a maximum of %d sessions", sessionLimit)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,10 +261,20 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile) (hosts map
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessionLimit, err := config.GetInt(id, "sessionlimit")
|
||||||
|
if err != nil || sessionLimit < 0 {
|
||||||
|
sessionLimit = 0
|
||||||
|
}
|
||||||
|
if sessionLimit > 0 {
|
||||||
|
log.Printf("Backend %s allows a maximum of %d sessions", id, sessionLimit)
|
||||||
|
}
|
||||||
|
|
||||||
hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{
|
hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{
|
||||||
id: id,
|
id: id,
|
||||||
url: u,
|
url: u,
|
||||||
secret: []byte(secret),
|
secret: []byte(secret),
|
||||||
|
|
||||||
|
sessionLimit: uint64(sessionLimit),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -242,6 +242,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
|
||||||
new_config.AddOption("backend", "allowall", "false")
|
new_config.AddOption("backend", "allowall", "false")
|
||||||
new_config.AddOption("backend1", "url", "http://domain3.invalid")
|
new_config.AddOption("backend1", "url", "http://domain3.invalid")
|
||||||
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||||
|
new_config.AddOption("backend1", "sessionlimit", "10")
|
||||||
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||||
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||||
n_cfg, err := NewBackendConfiguration(new_config)
|
n_cfg, err := NewBackendConfiguration(new_config)
|
||||||
|
@ -251,6 +252,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
|
||||||
|
|
||||||
original_config.RemoveOption("backend1", "url")
|
original_config.RemoveOption("backend1", "url")
|
||||||
original_config.AddOption("backend1", "url", "http://domain3.invalid")
|
original_config.AddOption("backend1", "url", "http://domain3.invalid")
|
||||||
|
original_config.AddOption("backend1", "sessionlimit", "10")
|
||||||
|
|
||||||
o_cfg.Reload(original_config)
|
o_cfg.Reload(original_config)
|
||||||
if !reflect.DeepEqual(n_cfg, o_cfg) {
|
if !reflect.DeepEqual(n_cfg, o_cfg) {
|
||||||
|
@ -310,6 +312,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
|
||||||
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
|
||||||
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
new_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||||
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||||
|
new_config.AddOption("backend2", "sessionlimit", "10")
|
||||||
n_cfg, err := NewBackendConfiguration(new_config)
|
n_cfg, err := NewBackendConfiguration(new_config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -319,6 +322,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
|
||||||
original_config.AddOption("backend", "backends", "backend1, backend2")
|
original_config.AddOption("backend", "backends", "backend1, backend2")
|
||||||
original_config.AddOption("backend2", "url", "http://domain2.invalid")
|
original_config.AddOption("backend2", "url", "http://domain2.invalid")
|
||||||
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
|
||||||
|
original_config.AddOption("backend2", "sessionlimit", "10")
|
||||||
|
|
||||||
o_cfg.Reload(original_config)
|
o_cfg.Reload(original_config)
|
||||||
if !reflect.DeepEqual(n_cfg, o_cfg) {
|
if !reflect.DeepEqual(n_cfg, o_cfg) {
|
||||||
|
|
|
@ -329,6 +329,7 @@ func (s *ClientSession) closeAndWait(wait bool) {
|
||||||
s.virtualSessions = nil
|
s.virtualSessions = nil
|
||||||
s.releaseMcuObjects()
|
s.releaseMcuObjects()
|
||||||
s.clearClientLocked(nil)
|
s.clearClientLocked(nil)
|
||||||
|
s.backend.RemoveSession(s)
|
||||||
if atomic.CompareAndSwapInt32(&s.running, 1, 0) {
|
if atomic.CompareAndSwapInt32(&s.running, 1, 0) {
|
||||||
s.stopRun <- true
|
s.stopRun <- true
|
||||||
// Only wait if called from outside the Session goroutine.
|
// Only wait if called from outside the Session goroutine.
|
||||||
|
|
|
@ -705,6 +705,13 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := backend.AddSession(session); err != nil {
|
||||||
|
log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), err)
|
||||||
|
session.Close()
|
||||||
|
client.SendMessage(message.NewWrappedErrorServerMessage(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.mu.Lock()
|
h.mu.Lock()
|
||||||
if !client.IsConnected() {
|
if !client.IsConnected() {
|
||||||
// Client disconnected while waiting for backend response.
|
// Client disconnected while waiting for backend response.
|
||||||
|
|
|
@ -431,6 +431,120 @@ func TestClientHelloAllowAll(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientHelloSessionLimit(t *testing.T) {
|
||||||
|
hub, _, router, server, shutdown := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
|
||||||
|
config, err := getTestConfig(server)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
config.RemoveOption("backend", "allowed")
|
||||||
|
config.RemoveOption("backend", "secret")
|
||||||
|
config.AddOption("backend", "backends", "backend1, backend2")
|
||||||
|
|
||||||
|
config.AddOption("backend1", "url", server.URL+"/one")
|
||||||
|
config.AddOption("backend1", "secret", string(testBackendSecret))
|
||||||
|
config.AddOption("backend1", "sessionlimit", "1")
|
||||||
|
|
||||||
|
config.AddOption("backend2", "url", server.URL+"/two")
|
||||||
|
config.AddOption("backend2", "secret", string(testBackendSecret))
|
||||||
|
return config, nil
|
||||||
|
})
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
|
registerBackendHandlerUrl(t, router, "/one")
|
||||||
|
registerBackendHandlerUrl(t, router, "/two")
|
||||||
|
|
||||||
|
client := NewTestClient(t, server, hub)
|
||||||
|
defer client.CloseWithBye()
|
||||||
|
|
||||||
|
params1 := TestBackendClientAuthParams{
|
||||||
|
UserId: testDefaultUserId,
|
||||||
|
}
|
||||||
|
if err := client.SendHelloParams(server.URL+"/one", "client", params1); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if hello, err := client.RunUntilHello(ctx); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
if hello.Hello.UserId != testDefaultUserId {
|
||||||
|
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
|
||||||
|
}
|
||||||
|
if hello.Hello.SessionId == "" {
|
||||||
|
t.Errorf("Expected session id, got %+v", hello.Hello)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The second client can't connect as it would exceed the session limit.
|
||||||
|
client2 := NewTestClient(t, server, hub)
|
||||||
|
defer client2.CloseWithBye()
|
||||||
|
|
||||||
|
params2 := TestBackendClientAuthParams{
|
||||||
|
UserId: testDefaultUserId + "2",
|
||||||
|
}
|
||||||
|
if err := client2.SendHelloParams(server.URL+"/one", "client", params2); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := client2.RunUntilMessage(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
if msg.Type != "error" || msg.Error == nil {
|
||||||
|
t.Errorf("Expected error message, got %+v", msg)
|
||||||
|
} else if msg.Error.Code != "session_limit_exceeded" {
|
||||||
|
t.Errorf("Expected error \"session_limit_exceeded\", got %+v", msg.Error.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The client can connect to a different backend.
|
||||||
|
if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hello, err := client2.RunUntilHello(ctx); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
if hello.Hello.UserId != testDefaultUserId+"2" {
|
||||||
|
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello.Hello)
|
||||||
|
}
|
||||||
|
if hello.Hello.SessionId == "" {
|
||||||
|
t.Errorf("Expected session id, got %+v", hello.Hello)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the first client disconnects (and releases the session), a new one can connect.
|
||||||
|
client.CloseWithBye()
|
||||||
|
if err := client.WaitForClientRemoved(ctx); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client3 := NewTestClient(t, server, hub)
|
||||||
|
defer client3.CloseWithBye()
|
||||||
|
|
||||||
|
params3 := TestBackendClientAuthParams{
|
||||||
|
UserId: testDefaultUserId + "3",
|
||||||
|
}
|
||||||
|
if err := client3.SendHelloParams(server.URL+"/one", "client", params3); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hello, err := client3.RunUntilHello(ctx); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
if hello.Hello.UserId != testDefaultUserId+"3" {
|
||||||
|
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"3", hello.Hello)
|
||||||
|
}
|
||||||
|
if hello.Hello.SessionId == "" {
|
||||||
|
t.Errorf("Expected session id, got %+v", hello.Hello)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSessionIdsUnordered(t *testing.T) {
|
func TestSessionIdsUnordered(t *testing.T) {
|
||||||
hub, _, _, server, shutdown := CreateHubForTest(t)
|
hub, _, _, server, shutdown := CreateHubForTest(t)
|
||||||
defer shutdown()
|
defer shutdown()
|
||||||
|
|
Loading…
Reference in a new issue