nextcloud-spreed-signaling/backend_storage_static.go
2025-12-09 15:26:47 +01:00

448 lines
13 KiB
Go

/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"net/url"
"slices"
"strings"
"github.com/dlintw/goconf"
"github.com/strukturag/nextcloud-spreed-signaling/api"
"github.com/strukturag/nextcloud-spreed-signaling/internal"
"github.com/strukturag/nextcloud-spreed-signaling/log"
)
// backendStorageStatic is a BackendStorage whose backends are read from the
// static configuration file. The backend list can be refreshed with Reload
// unless the old-style ("allowall" / "allowed") configuration is active.
type backendStorageStatic struct {
	// Embedded; provides the host -> backends map (backends), the statistics
	// (stats) and the lock (mu) used by the methods of this type
	// (presumably also getBackendLocked — confirm).
	backendStorageCommon

	logger log.Logger
	// Backends indexed by their id. Modified by RemoveBackendsForHost and
	// UpsertHost, which expect s.mu to be held.
	backendsById map[string]*Backend

	// Deprecated: the fields below serve the old-style configuration.

	// Value of "[backend] allowall"; when set, GetBackend returns
	// compatBackend for hosts without configured backends.
	allowAll bool
	// Value of "[backend] secret" when the storage was created; Reload does
	// not update it.
	commonSecret []byte
	// Shared "compat" backend of the "allowall" / "allowed" configurations,
	// nil otherwise. While it is set, Reload is refused.
	compatBackend *Backend
}
// NewBackendStorageStatic creates a BackendStorage from the "[backend]"
// section of config.
//
// Three mutually exclusive setups are supported, checked in this order:
//   - "allowall" enabled: a single "compat" backend accepts requests for any
//     host (development only).
//   - "backends" set: every listed section describes one backend with its own
//     URLs, secret and limits (see getConfiguredHosts).
//   - "allowed" set (deprecated): a list of hostnames that all share one
//     "compat" backend using the common secret.
//
// The number of created backends is added to stats. The returned error is
// currently always nil.
func NewBackendStorageStatic(logger log.Logger, config *goconf.ConfigFile, stats BackendStorageStats) (BackendStorage, error) {
	allowAll, _ := config.GetBool("backend", "allowall")
	allowHttp, _ := config.GetBool("backend", "allowhttp")
	commonSecret, _ := GetStringOptionWithEnv(config, "backend", "secret")
	sessionLimit, err := config.GetInt("backend", "sessionlimit")
	if err != nil || sessionLimit < 0 {
		sessionLimit = 0
	}

	// newCompatBackend builds the single "compat" backend shared by the
	// "allowall" and the deprecated "allowed" setups from the global
	// "[backend]" settings. Invalid or negative bitrates mean "no limit" (0).
	newCompatBackend := func() *Backend {
		maxStreamBitrate, err := config.GetInt("backend", "maxstreambitrate")
		if err != nil || maxStreamBitrate < 0 {
			maxStreamBitrate = 0
		}
		maxScreenBitrate, err := config.GetInt("backend", "maxscreenbitrate")
		if err != nil || maxScreenBitrate < 0 {
			maxScreenBitrate = 0
		}
		return &Backend{
			id:               "compat",
			secret:           []byte(commonSecret),
			allowHttp:        allowHttp,
			sessionLimit:     uint64(sessionLimit),
			counted:          true,
			maxStreamBitrate: api.BandwidthFromBits(uint64(maxStreamBitrate)),
			maxScreenBitrate: api.BandwidthFromBits(uint64(maxScreenBitrate)),
		}
	}

	backends := make(map[string][]*Backend)
	backendsById := make(map[string]*Backend)
	var compatBackend *Backend
	numBackends := 0
	if allowAll {
		logger.Println("WARNING: All backend hostnames are allowed, only use for development!")
		compatBackend = newCompatBackend()
		if sessionLimit > 0 {
			logger.Printf("Allow a maximum of %d sessions", sessionLimit)
		}
		updateBackendStats(compatBackend)
		backendsById[compatBackend.id] = compatBackend
		numBackends++
	} else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
		// A backend with URLs on several hosts is returned once per host, so
		// collect them by id to register and count each backend only once.
		added := make(map[string]*Backend)
		for host, configuredBackends := range getConfiguredHosts(logger, backendIds, config, commonSecret) {
			backends[host] = append(backends[host], configuredBackends...)
			for _, be := range configuredBackends {
				added[be.id] = be
			}
		}

		for _, be := range added {
			logger.Printf("Backend %s added for %s", be.id, strings.Join(be.urls, ", "))
			backendsById[be.id] = be
			updateBackendStats(be)
			be.counted = true
		}
		numBackends += len(added)
	} else if allowedUrls, _ := config.GetString("backend", "allowed"); allowedUrls != "" {
		// Old-style configuration, only hosts are configured and are using a common secret.
		allowMap := make(map[string]bool)
		for u := range SplitEntries(allowedUrls, ",") {
			if idx := strings.IndexByte(u, '/'); idx != -1 {
				logger.Printf("WARNING: Removing path from allowed hostname \"%s\", check your configuration!", u)
				if u = u[:idx]; u == "" {
					continue
				}
			}

			allowMap[strings.ToLower(u)] = true
		}

		if len(allowMap) == 0 {
			logger.Println("WARNING: No backend hostnames are allowed, check your configuration!")
		} else {
			compatBackend = newCompatBackend()
			hosts := make([]string, 0, len(allowMap))
			for host := range allowMap {
				hosts = append(hosts, host)
				backends[host] = []*Backend{compatBackend}
			}
			// Map iteration order is random; sort so the log output is stable.
			slices.Sort(hosts)
			if len(hosts) > 1 {
				logger.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.")
			}
			logger.Printf("Allowed backend hostnames: %s", hosts)
			if sessionLimit > 0 {
				logger.Printf("Allow a maximum of %d sessions", sessionLimit)
			}
			updateBackendStats(compatBackend)
			backendsById[compatBackend.id] = compatBackend
			numBackends++
		}
	}

	if numBackends == 0 {
		logger.Printf("WARNING: No backends configured, client connections will not be possible.")
	}

	stats.AddBackends(numBackends)
	return &backendStorageStatic{
		backendStorageCommon: backendStorageCommon{
			backends: backends,
			stats:    stats,
		},
		logger:        logger,
		backendsById:  backendsById,
		allowAll:      allowAll,
		commonSecret:  []byte(commonSecret),
		compatBackend: compatBackend,
	}, nil
}
// Close implements BackendStorage. The static storage holds nothing that
// needs to be released, so this is a no-op.
func (*backendStorageStatic) Close() {}
// RemoveBackendsForHost drops host and every backend registered for it.
//
// A backend is only removed from backendsById and the statistics when all of
// its URLs point to host and it is still counted; a backend that also has
// URLs for other hosts stays registered.
//
// seen is shared during one Reload pass: a backend already marked seenDeleted
// is skipped, so it is only logged and evaluated once.
//
// NOTE(review): because of that skip, a backend whose URLs span several hosts
// that are all removed is only evaluated for the first host. That host never
// holds all of its URLs, so the backend appears to stay in backendsById and
// the statistics. Confirm and revisit together with UpsertHost.
//
// +checklocks:s.mu
func (s *backendStorageStatic) RemoveBackendsForHost(host string, seen map[string]seenState) {
	if oldBackends := s.backends[host]; len(oldBackends) > 0 {
		deleted := 0
		for _, backend := range oldBackends {
			id := backend.Id()
			if seen[id] == seenDeleted {
				continue
			}

			seen[id] = seenDeleted
			// Filter a copy: slices.DeleteFunc works in place and would
			// otherwise reorder and zero entries of backend.urls, even though
			// the same backend may still be registered for other hosts.
			urls := slices.DeleteFunc(slices.Clone(backend.urls), func(u string) bool {
				// Compare the exact host instead of a substring match, so
				// that "example.com" does not also match URLs of
				// "example.community" or "example.com.evil".
				parsed, err := url.Parse(u)
				return err != nil || parsed.Host != host
			})
			s.logger.Printf("Backend %s removed for %s", backend.id, strings.Join(urls, ", "))
			if len(urls) == len(backend.urls) && backend.counted {
				deleteBackendStats(backend)
				delete(s.backendsById, id)
				deleted++
				backend.counted = false
			}
		}

		s.stats.RemoveBackends(deleted)
	}

	delete(s.backends, host)
}
// seenState records what a single Reload pass already did for a backend id,
// so that a backend registered for several hosts is only logged and counted
// once per state.
type seenState int

const (
	// seenNotSeen is the zero value: the backend id was not touched yet.
	seenNotSeen seenState = iota
	// seenAdded: the backend was added for at least one host.
	seenAdded
	// seenUpdated: the backend was replaced by a changed configuration.
	seenUpdated
	// seenDeleted: the backend was removed for at least one host.
	seenDeleted
)
// UpsertHost replaces the backends registered for host with backends.
//
// Every existing backend of host is matched against the new list:
//   - an equal backend keeps the existing instance,
//   - a backend with the same id but different settings replaces it and is
//     re-registered in backendsById (once per Reload pass, see seen),
//   - an existing backend without a match is removed from host, and from
//     backendsById and the statistics if all of its URLs point to host.
//
// New backends without a matching existing one are appended and registered;
// uncounted ones are added to the statistics.
//
// NOTE(review): a removal is only evaluated once per backend id (see seen) and
// only considers the URLs of the current host. A backend whose URLs span
// several hosts may therefore stay in backendsById and the statistics after it
// was dropped everywhere. Confirm and revisit together with
// RemoveBackendsForHost.
//
// +checklocks:s.mu
func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend, seen map[string]seenState) {
	// Matched entries are removed from this copy, so only genuinely new
	// backends remain afterwards; the caller's slice is left untouched.
	pending := slices.Clone(backends)
	existing := s.backends[host]

	// Build the resulting host list separately. Deleting from s.backends[host]
	// while ranging over it shifts later entries into already visited slots
	// (skipping them). slices.Delete also zeroes the vacated tail, so such a
	// loop would go on with nil backends.
	merged := make([]*Backend, 0, len(existing)+len(pending))
	for _, existingBackend := range existing {
		idx := slices.IndexFunc(pending, func(b *Backend) bool {
			return existingBackend.Equal(b) || b.id == existingBackend.id
		})
		if idx == -1 {
			// No longer configured for this host.
			if seen[existingBackend.id] != seenDeleted {
				seen[existingBackend.id] = seenDeleted
				// Filter a copy (DeleteFunc works in place) and compare exact
				// hosts instead of a substring match.
				urls := slices.DeleteFunc(slices.Clone(existingBackend.urls), func(u string) bool {
					parsed, err := url.Parse(u)
					return err != nil || parsed.Host != host
				})
				s.logger.Printf("Backend %s removed for %s", existingBackend.id, strings.Join(urls, ", "))
				if len(urls) == len(existingBackend.urls) && existingBackend.counted {
					deleteBackendStats(existingBackend)
					delete(s.backendsById, existingBackend.Id())
					s.stats.DecBackends()
					existingBackend.counted = false
				}
			}
			continue
		}

		newBackend := pending[idx]
		pending = slices.Delete(pending, idx, idx+1)
		if existingBackend.Equal(newBackend) {
			// Unchanged: keep the existing instance.
			merged = append(merged, existingBackend)
			continue
		}

		// Same id, changed settings: replace the existing instance.
		merged = append(merged, newBackend)
		if seen[newBackend.id] != seenUpdated {
			seen[newBackend.id] = seenUpdated
			s.logger.Printf("Backend %s updated for %s", newBackend.id, strings.Join(newBackend.urls, ", "))
			updateBackendStats(newBackend)
			newBackend.counted = existingBackend.counted
			s.backendsById[newBackend.id] = newBackend
		}
	}

	merged = append(merged, pending...)
	s.backends[host] = merged

	addedBackends := 0
	for _, added := range pending {
		if seen[added.id] == seenAdded {
			continue
		}

		seen[added.id] = seenAdded
		// The same backend may already be registered through another host;
		// inherit its counted state so it is not counted twice.
		if prev, found := s.backendsById[added.id]; found {
			added.counted = prev.counted
		} else {
			s.backendsById[added.id] = added
		}
		s.logger.Printf("Backend %s added for %s", added.id, strings.Join(added.urls, ", "))
		if !added.counted {
			updateBackendStats(added)
			addedBackends++
			added.counted = true
		}
	}
	s.stats.AddBackends(addedBackends)
}
// getConfiguredBackendIDs splits the comma-separated backendIds list into its
// entries (as yielded by SplitEntries). Duplicates are dropped while the order
// of first occurrence is kept. It returns nil when there are no entries.
func getConfiguredBackendIDs(backendIds string) (ids []string) {
	known := make(map[string]struct{})
	for entry := range SplitEntries(backendIds, ",") {
		if _, duplicate := known[entry]; !duplicate {
			known[entry] = struct{}{}
			ids = append(ids, entry)
		}
	}
	return ids
}
// getConfiguredHosts parses the backend sections listed in backendIds and
// groups the resulting backends by the host of their URLs.
//
// Each section may set "secret" (falling back to commonSecret),
// "sessionlimit", "maxstreambitrate", "maxscreenbitrate" (invalid or negative
// values mean 0), and either a comma-separated "urls" list or a single "url".
// Sections without a secret or without URLs are skipped. Every URL is made to
// end with "/" and canonicalized. URLs that fail to parse, have no host, or
// were already claimed by an earlier backend are skipped. A backend with URLs
// on several hosts is returned under each of them as the same *Backend.
func getConfiguredHosts(logger log.Logger, backendIds string, config *goconf.ConfigFile, commonSecret string) (hosts map[string][]*Backend) {
	hosts = make(map[string][]*Backend)
	seenUrls := make(map[string]string)
	for _, id := range getConfiguredBackendIDs(backendIds) {
		secret, _ := GetStringOptionWithEnv(config, id, "secret")
		if secret == "" && commonSecret != "" {
			logger.Printf("Backend %s has no own shared secret set, using common shared secret", id)
			secret = commonSecret
		}
		if secret == "" {
			logger.Printf("Backend %s is missing or incomplete, skipping", id)
			continue
		}

		sessionLimit, err := config.GetInt(id, "sessionlimit")
		if err != nil || sessionLimit < 0 {
			sessionLimit = 0
		}
		if sessionLimit > 0 {
			logger.Printf("Backend %s allows a maximum of %d sessions", id, sessionLimit)
		}

		maxStreamBitrate, err := config.GetInt(id, "maxstreambitrate")
		if err != nil || maxStreamBitrate < 0 {
			maxStreamBitrate = 0
		}
		maxScreenBitrate, err := config.GetInt(id, "maxscreenbitrate")
		if err != nil || maxScreenBitrate < 0 {
			maxScreenBitrate = 0
		}

		var urls []string
		if u, _ := GetStringOptionWithEnv(config, id, "urls"); u != "" {
			urls = slices.Sorted(SplitEntries(u, ","))
			urls = slices.Compact(urls)
		} else if u, _ := GetStringOptionWithEnv(config, id, "url"); u != "" {
			if u = strings.TrimSpace(u); u != "" {
				urls = []string{u}
			}
		}

		if len(urls) == 0 {
			logger.Printf("Backend %s is missing or incomplete, skipping", id)
			continue
		}

		backend := &Backend{
			id:     id,
			secret: []byte(secret),

			maxStreamBitrate: api.BandwidthFromBits(uint64(maxStreamBitrate)),
			maxScreenBitrate: api.BandwidthFromBits(uint64(maxScreenBitrate)),

			sessionLimit: uint64(sessionLimit),
		}

		added := make(map[string]bool)
		for _, u := range urls {
			// HasSuffix instead of indexing the last byte: safe for "".
			if !strings.HasSuffix(u, "/") {
				u += "/"
			}

			parsed, err := url.Parse(u)
			if err != nil {
				logger.Printf("Backend %s has an invalid url %s configured (%s), skipping", id, u, err)
				continue
			}

			var changed bool
			if parsed, changed = internal.CanonicalizeUrl(parsed); changed {
				u = parsed.String()
			}

			if parsed.Host == "" {
				// Backends are looked up by the host of the request URL, so a
				// URL without host (e.g. a missing scheme) could never match.
				logger.Printf("Backend %s has an url %s without host configured, skipping", id, u)
				continue
			}

			if prev, found := seenUrls[u]; found {
				logger.Printf("Url %s in backend %s was already used in backend %s, skipping", u, id, prev)
				continue
			}

			seenUrls[u] = id
			backend.urls = append(backend.urls, u)
			if parsed.Scheme == "http" {
				backend.allowHttp = true
			}

			if !added[parsed.Host] {
				hosts[parsed.Host] = append(hosts[parsed.Host], backend)
				added[parsed.Host] = true
			}
		}

		if len(backend.urls) == 0 {
			// Not registered for any host; make that visible to operators.
			logger.Printf("Backend %s has no usable urls configured, skipping", id)
		}
	}

	return hosts
}
// Reload re-reads the "[backend] backends" configuration and applies the
// differences under the write lock.
//
// Hosts that are no longer configured are removed and all configured hosts are
// upserted. Without a "backends" list every host is removed. Reloading is
// refused while the old-style compat backend is active.
func (s *backendStorageStatic) Reload(config *goconf.ConfigFile) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.compatBackend != nil {
		s.logger.Println("Old-style configuration active, reload is not supported")
		return
	}

	commonSecret, _ := GetStringOptionWithEnv(config, "backend", "secret")
	// Shared by all removals/upserts of this pass so that backends listed
	// under several hosts are handled once per state.
	seen := make(map[string]seenState)

	backendIds, _ := config.GetString("backend", "backends")
	if backendIds == "" {
		// Nothing configured any more: drop every known host.
		for hostname := range s.backends {
			s.RemoveBackendsForHost(hostname, seen)
		}
		return
	}

	configuredHosts := getConfiguredHosts(s.logger, backendIds, config, commonSecret)
	for hostname := range s.backends {
		if _, stillConfigured := configuredHosts[hostname]; !stillConfigured {
			s.RemoveBackendsForHost(hostname, seen)
		}
	}
	for hostname, configuredBackends := range configuredHosts {
		s.UpsertHost(hostname, configuredBackends, seen)
	}
}
// GetCompatBackend returns the shared "compat" backend of the "allowall" or
// deprecated "allowed" configuration, or nil if none was created.
func (s *backendStorageStatic) GetCompatBackend() *Backend {
	s.mu.RLock()
	compat := s.compatBackend
	s.mu.RUnlock()
	return compat
}
// GetBackend returns the backend responsible for u.
//
// If backends are registered for u.Host, the lookup is delegated to
// getBackendLocked. Otherwise the compat backend is returned when "allowall"
// is enabled, and nil if it is not.
func (s *backendStorageStatic) GetBackend(u *url.URL) *Backend {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if _, known := s.backends[u.Host]; known {
		return s.getBackendLocked(u)
	}
	if s.allowAll {
		return s.compatBackend
	}
	return nil
}