Merge pull request #67 from strukturag/limit-sessions

Support limiting number of sessions per backend.
This commit is contained in:
Joachim Bauch 2021-01-07 09:39:35 +01:00 committed by GitHub
commit b89e017ae4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 197 additions and 0 deletions

View File

@ -82,6 +82,10 @@ connectionsperhost = 8
# same value as configured in the Nextcloud admin ui.
#secret = the-shared-secret
# Limit the number of sessions that are allowed to connect to this backend.
# Omit or set to 0 to not limit the number of sessions.
#sessionlimit = 10
#[another-backend]
# URL of the Nextcloud instance
#url = https://cloud.otherdomain.invalid

View File

@ -104,6 +104,10 @@ func (m *ClientMessage) NewErrorServerMessage(e *Error) *ServerMessage {
}
func (m *ClientMessage) NewWrappedErrorServerMessage(e error) *ServerMessage {
if e, ok := e.(*Error); ok {
return m.NewErrorServerMessage(e)
}
return m.NewErrorServerMessage(NewError("internal_error", e.Error()))
}

View File

@ -26,15 +26,24 @@ import (
"net/url"
"reflect"
"strings"
"sync"
"github.com/dlintw/goconf"
)
var (
SessionLimitExceeded = NewError("session_limit_exceeded", "Too many sessions connected for this backend.")
)
type Backend struct {
id string
url string
secret []byte
compat bool
sessionLimit uint64
sessionsLock sync.Mutex
sessions map[string]bool
}
func (b *Backend) Id() string {
@ -49,6 +58,36 @@ func (b *Backend) IsCompat() bool {
return b.compat
}
func (b *Backend) AddSession(session Session) error {
if session.ClientType() == HelloClientTypeInternal || session.ClientType() == HelloClientTypeVirtual {
// Internal and virtual sessions are not counting to the limit.
return nil
}
if b.sessionLimit == 0 {
// Not limited
return nil
}
b.sessionsLock.Lock()
defer b.sessionsLock.Unlock()
if b.sessions == nil {
b.sessions = make(map[string]bool)
} else if uint64(len(b.sessions)) >= b.sessionLimit {
return SessionLimitExceeded
}
b.sessions[session.PublicId()] = true
return nil
}
func (b *Backend) RemoveSession(session Session) {
b.sessionsLock.Lock()
defer b.sessionsLock.Unlock()
delete(b.sessions, session.PublicId())
}
type BackendConfiguration struct {
backends map[string][]*Backend
@ -61,6 +100,10 @@ type BackendConfiguration struct {
func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, error) {
allowAll, _ := config.GetBool("backend", "allowall")
commonSecret, _ := config.GetString("backend", "secret")
sessionLimit, err := config.GetInt("backend", "sessionlimit")
if err != nil || sessionLimit < 0 {
sessionLimit = 0
}
backends := make(map[string][]*Backend)
var compatBackend *Backend
if allowAll {
@ -69,6 +112,11 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration,
id: "compat",
secret: []byte(commonSecret),
compat: true,
sessionLimit: uint64(sessionLimit),
}
if sessionLimit > 0 {
log.Printf("Allow a maximum of %d sessions", sessionLimit)
}
} else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
for host, configuredBackends := range getConfiguredHosts(backendIds, config) {
@ -98,6 +146,8 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration,
id: "compat",
secret: []byte(commonSecret),
compat: true,
sessionLimit: uint64(sessionLimit),
}
hosts := make([]string, 0, len(allowMap))
for host := range allowMap {
@ -108,6 +158,9 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration,
log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.")
}
log.Printf("Allowed backend hostnames: %s\n", hosts)
if sessionLimit > 0 {
log.Printf("Allow a maximum of %d sessions", sessionLimit)
}
}
}
@ -208,10 +261,20 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile) (hosts map
continue
}
sessionLimit, err := config.GetInt(id, "sessionlimit")
if err != nil || sessionLimit < 0 {
sessionLimit = 0
}
if sessionLimit > 0 {
log.Printf("Backend %s allows a maximum of %d sessions", id, sessionLimit)
}
hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{
id: id,
url: u,
secret: []byte(secret),
sessionLimit: uint64(sessionLimit),
})
}

View File

@ -242,6 +242,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
new_config.AddOption("backend", "allowall", "false")
new_config.AddOption("backend1", "url", "http://domain3.invalid")
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
new_config.AddOption("backend1", "sessionlimit", "10")
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
n_cfg, err := NewBackendConfiguration(new_config)
@ -251,6 +252,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
original_config.RemoveOption("backend1", "url")
original_config.AddOption("backend1", "url", "http://domain3.invalid")
original_config.AddOption("backend1", "sessionlimit", "10")
o_cfg.Reload(original_config)
if !reflect.DeepEqual(n_cfg, o_cfg) {
@ -310,6 +312,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
new_config.AddOption("backend2", "sessionlimit", "10")
n_cfg, err := NewBackendConfiguration(new_config)
if err != nil {
t.Fatal(err)
@ -319,6 +322,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
original_config.AddOption("backend", "backends", "backend1, backend2")
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
original_config.AddOption("backend2", "sessionlimit", "10")
o_cfg.Reload(original_config)
if !reflect.DeepEqual(n_cfg, o_cfg) {

View File

@ -329,6 +329,7 @@ func (s *ClientSession) closeAndWait(wait bool) {
s.virtualSessions = nil
s.releaseMcuObjects()
s.clearClientLocked(nil)
s.backend.RemoveSession(s)
if atomic.CompareAndSwapInt32(&s.running, 1, 0) {
s.stopRun <- true
// Only wait if called from outside the Session goroutine.

View File

@ -705,6 +705,13 @@ func (h *Hub) processRegister(client *Client, message *ClientMessage, backend *B
return
}
if err := backend.AddSession(session); err != nil {
log.Printf("Error adding session %s to backend %s: %s", session.PublicId(), backend.Id(), err)
session.Close()
client.SendMessage(message.NewWrappedErrorServerMessage(err))
return
}
h.mu.Lock()
if !client.IsConnected() {
// Client disconnected while waiting for backend response.

View File

@ -431,6 +431,120 @@ func TestClientHelloAllowAll(t *testing.T) {
}
}
func TestClientHelloSessionLimit(t *testing.T) {
hub, _, router, server, shutdown := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) {
config, err := getTestConfig(server)
if err != nil {
return nil, err
}
config.RemoveOption("backend", "allowed")
config.RemoveOption("backend", "secret")
config.AddOption("backend", "backends", "backend1, backend2")
config.AddOption("backend1", "url", server.URL+"/one")
config.AddOption("backend1", "secret", string(testBackendSecret))
config.AddOption("backend1", "sessionlimit", "1")
config.AddOption("backend2", "url", server.URL+"/two")
config.AddOption("backend2", "secret", string(testBackendSecret))
return config, nil
})
defer shutdown()
registerBackendHandlerUrl(t, router, "/one")
registerBackendHandlerUrl(t, router, "/two")
client := NewTestClient(t, server, hub)
defer client.CloseWithBye()
params1 := TestBackendClientAuthParams{
UserId: testDefaultUserId,
}
if err := client.SendHelloParams(server.URL+"/one", "client", params1); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if hello, err := client.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != testDefaultUserId {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
}
// The second client can't connect as it would exceed the session limit.
client2 := NewTestClient(t, server, hub)
defer client2.CloseWithBye()
params2 := TestBackendClientAuthParams{
UserId: testDefaultUserId + "2",
}
if err := client2.SendHelloParams(server.URL+"/one", "client", params2); err != nil {
t.Fatal(err)
}
msg, err := client2.RunUntilMessage(ctx)
if err != nil {
t.Error(err)
} else {
if msg.Type != "error" || msg.Error == nil {
t.Errorf("Expected error message, got %+v", msg)
} else if msg.Error.Code != "session_limit_exceeded" {
t.Errorf("Expected error \"session_limit_exceeded\", got %+v", msg.Error.Code)
}
}
// The client can connect to a different backend.
if err := client2.SendHelloParams(server.URL+"/two", "client", params2); err != nil {
t.Fatal(err)
}
if hello, err := client2.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != testDefaultUserId+"2" {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
}
// If the first client disconnects (and releases the session), a new one can connect.
client.CloseWithBye()
if err := client.WaitForClientRemoved(ctx); err != nil {
t.Error(err)
}
client3 := NewTestClient(t, server, hub)
defer client3.CloseWithBye()
params3 := TestBackendClientAuthParams{
UserId: testDefaultUserId + "3",
}
if err := client3.SendHelloParams(server.URL+"/one", "client", params3); err != nil {
t.Fatal(err)
}
if hello, err := client3.RunUntilHello(ctx); err != nil {
t.Error(err)
} else {
if hello.Hello.UserId != testDefaultUserId+"3" {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"3", hello.Hello)
}
if hello.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello.Hello)
}
}
}
func TestSessionIdsUnordered(t *testing.T) {
hub, _, _, server, shutdown := CreateHubForTest(t)
defer shutdown()