Allow configuring backends through etcd.

This commit is contained in:
Joachim Bauch 2022-06-30 11:34:32 +02:00
parent 01858a89f4
commit 24eab34da7
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
13 changed files with 903 additions and 295 deletions

View file

@ -28,7 +28,10 @@ import (
"crypto/subtle"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
)
const (
@ -321,3 +324,39 @@ type TurnCredentials struct {
TTL int64 `json:"ttl"`
URIs []string `json:"uris"`
}
// Information on a backend in the etcd cluster.
type BackendInformationEtcd struct {
parsedUrl *url.URL
Url string `json:"url"`
Secret string `json:"secret"`
MaxStreamBitrate int `json:"maxstreambitrate,omitempty"`
MaxScreenBitrate int `json:"maxscreenbitrate,omitempty"`
SessionLimit uint64 `json:"sessionlimit,omitempty"`
}
func (p *BackendInformationEtcd) CheckValid() error {
if p.Url == "" {
return fmt.Errorf("url missing")
}
if p.Secret == "" {
return fmt.Errorf("secret missing")
}
parsedUrl, err := url.Parse(p.Url)
if err != nil {
return fmt.Errorf("invalid url: %w", err)
}
if strings.Contains(parsedUrl.Host, ":") && hasStandardPort(parsedUrl) {
parsedUrl.Host = parsedUrl.Hostname()
p.Url = parsedUrl.String()
}
p.parsedUrl = parsedUrl
return nil
}

View file

@ -50,8 +50,8 @@ type BackendClient struct {
capabilities *Capabilities
}
func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string) (*BackendClient, error) {
backends, err := NewBackendConfiguration(config)
func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost int, version string, etcdClient *EtcdClient) (*BackendClient, error) {
backends, err := NewBackendConfiguration(config, etcdClient)
if err != nil {
return nil, err
}
@ -80,6 +80,10 @@ func NewBackendClient(config *goconf.ConfigFile, maxConcurrentRequestsPerHost in
}, nil
}
func (b *BackendClient) Close() {
b.backends.Close()
}
func (b *BackendClient) Reload(config *goconf.ConfigFile) {
b.backends.Reload(config)
}

View file

@ -95,7 +95,7 @@ func TestPostOnRedirect(t *testing.T) {
if u.Scheme == "http" {
config.AddOption("backend", "allowhttp", "true")
}
client, err := NewBackendClient(config, 1, "0.0")
client, err := NewBackendClient(config, 1, "0.0", nil)
if err != nil {
t.Fatal(err)
}
@ -134,7 +134,7 @@ func TestPostOnRedirectDifferentHost(t *testing.T) {
if u.Scheme == "http" {
config.AddOption("backend", "allowhttp", "true")
}
client, err := NewBackendClient(config, 1, "0.0")
client, err := NewBackendClient(config, 1, "0.0", nil)
if err != nil {
t.Fatal(err)
}
@ -187,7 +187,7 @@ func TestPostOnRedirectStatusFound(t *testing.T) {
if u.Scheme == "http" {
config.AddOption("backend", "allowhttp", "true")
}
client, err := NewBackendClient(config, 1, "0.0")
client, err := NewBackendClient(config, 1, "0.0", nil)
if err != nil {
t.Fatal(err)
}

View file

@ -22,15 +22,21 @@
package signaling
import (
"log"
"fmt"
"net/url"
"reflect"
"strings"
"sync"
"github.com/dlintw/goconf"
)
const (
BackendTypeStatic = "static"
BackendTypeEtcd = "etcd"
DefaultBackendType = BackendTypeStatic
)
var (
SessionLimitExceeded = NewError("session_limit_exceeded", "Too many sessions connected for this backend.")
)
@ -105,271 +111,43 @@ func (b *Backend) RemoveSession(session Session) {
delete(b.sessions, session.PublicId())
}
type BackendConfiguration struct {
type BackendStorage interface {
Close()
Reload(config *goconf.ConfigFile)
GetCompatBackend() *Backend
GetBackend(u *url.URL) *Backend
GetBackends() []*Backend
}
type backendStorageCommon struct {
mu sync.RWMutex
backends map[string][]*Backend
// Deprecated
allowAll bool
commonSecret []byte
compatBackend *Backend
}
func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, error) {
allowAll, _ := config.GetBool("backend", "allowall")
allowHttp, _ := config.GetBool("backend", "allowhttp")
commonSecret, _ := config.GetString("backend", "secret")
sessionLimit, err := config.GetInt("backend", "sessionlimit")
if err != nil || sessionLimit < 0 {
sessionLimit = 0
func (s *backendStorageCommon) GetBackends() []*Backend {
s.mu.RLock()
defer s.mu.RUnlock()
var result []*Backend
for _, entries := range s.backends {
result = append(result, entries...)
}
backends := make(map[string][]*Backend)
var compatBackend *Backend
numBackends := 0
if allowAll {
log.Println("WARNING: All backend hostnames are allowed, only use for development!")
compatBackend = &Backend{
id: "compat",
secret: []byte(commonSecret),
compat: true,
allowHttp: allowHttp,
sessionLimit: uint64(sessionLimit),
}
if sessionLimit > 0 {
log.Printf("Allow a maximum of %d sessions", sessionLimit)
}
numBackends++
} else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
for host, configuredBackends := range getConfiguredHosts(backendIds, config) {
backends[host] = append(backends[host], configuredBackends...)
for _, be := range configuredBackends {
log.Printf("Backend %s added for %s", be.id, be.url)
}
numBackends += len(configuredBackends)
}
} else if allowedUrls, _ := config.GetString("backend", "allowed"); allowedUrls != "" {
// Old-style configuration, only hosts are configured and are using a common secret.
allowMap := make(map[string]bool)
for _, u := range strings.Split(allowedUrls, ",") {
u = strings.TrimSpace(u)
if idx := strings.IndexByte(u, '/'); idx != -1 {
log.Printf("WARNING: Removing path from allowed hostname \"%s\", check your configuration!", u)
u = u[:idx]
}
if u != "" {
allowMap[strings.ToLower(u)] = true
}
}
if len(allowMap) == 0 {
log.Println("WARNING: No backend hostnames are allowed, check your configuration!")
} else {
compatBackend = &Backend{
id: "compat",
secret: []byte(commonSecret),
compat: true,
allowHttp: allowHttp,
sessionLimit: uint64(sessionLimit),
}
hosts := make([]string, 0, len(allowMap))
for host := range allowMap {
hosts = append(hosts, host)
backends[host] = []*Backend{compatBackend}
}
if len(hosts) > 1 {
log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.")
}
log.Printf("Allowed backend hostnames: %s", hosts)
if sessionLimit > 0 {
log.Printf("Allow a maximum of %d sessions", sessionLimit)
}
numBackends++
}
}
RegisterBackendConfigurationStats()
statsBackendsCurrent.Add(float64(numBackends))
return &BackendConfiguration{
backends: backends,
allowAll: allowAll,
commonSecret: []byte(commonSecret),
compatBackend: compatBackend,
}, nil
return result
}
func (b *BackendConfiguration) RemoveBackendsForHost(host string) {
if oldBackends := b.backends[host]; len(oldBackends) > 0 {
for _, backend := range oldBackends {
log.Printf("Backend %s removed for %s", backend.id, backend.url)
}
statsBackendsCurrent.Sub(float64(len(oldBackends)))
}
delete(b.backends, host)
}
func (s *backendStorageCommon) getBackendLocked(u *url.URL) *Backend {
s.mu.RLock()
defer s.mu.RUnlock()
func (b *BackendConfiguration) UpsertHost(host string, backends []*Backend) {
for existingIndex, existingBackend := range b.backends[host] {
found := false
index := 0
for _, newBackend := range backends {
if reflect.DeepEqual(existingBackend, newBackend) { // otherwise we could manually compare the struct members here
found = true
backends = append(backends[:index], backends[index+1:]...)
break
} else if newBackend.id == existingBackend.id {
found = true
b.backends[host][existingIndex] = newBackend
backends = append(backends[:index], backends[index+1:]...)
log.Printf("Backend %s updated for %s", newBackend.id, newBackend.url)
break
}
index++
}
if !found {
removed := b.backends[host][existingIndex]
log.Printf("Backend %s removed for %s", removed.id, removed.url)
b.backends[host] = append(b.backends[host][:existingIndex], b.backends[host][existingIndex+1:]...)
statsBackendsCurrent.Dec()
}
}
b.backends[host] = append(b.backends[host], backends...)
for _, added := range backends {
log.Printf("Backend %s added for %s", added.id, added.url)
}
statsBackendsCurrent.Add(float64(len(backends)))
}
func getConfiguredBackendIDs(backendIds string) (ids []string) {
seen := make(map[string]bool)
for _, id := range strings.Split(backendIds, ",") {
id = strings.TrimSpace(id)
if id == "" {
continue
}
if seen[id] {
continue
}
ids = append(ids, id)
seen[id] = true
}
return ids
}
func getConfiguredHosts(backendIds string, config *goconf.ConfigFile) (hosts map[string][]*Backend) {
hosts = make(map[string][]*Backend)
for _, id := range getConfiguredBackendIDs(backendIds) {
u, _ := config.GetString(id, "url")
if u == "" {
log.Printf("Backend %s is missing or incomplete, skipping", id)
continue
}
if u[len(u)-1] != '/' {
u += "/"
}
parsed, err := url.Parse(u)
if err != nil {
log.Printf("Backend %s has an invalid url %s configured (%s), skipping", id, u, err)
continue
}
if strings.Contains(parsed.Host, ":") && hasStandardPort(parsed) {
parsed.Host = parsed.Hostname()
u = parsed.String()
}
secret, _ := config.GetString(id, "secret")
if u == "" || secret == "" {
log.Printf("Backend %s is missing or incomplete, skipping", id)
continue
}
sessionLimit, err := config.GetInt(id, "sessionlimit")
if err != nil || sessionLimit < 0 {
sessionLimit = 0
}
if sessionLimit > 0 {
log.Printf("Backend %s allows a maximum of %d sessions", id, sessionLimit)
}
maxStreamBitrate, err := config.GetInt(id, "maxstreambitrate")
if err != nil || maxStreamBitrate < 0 {
maxStreamBitrate = 0
}
maxScreenBitrate, err := config.GetInt(id, "maxscreenbitrate")
if err != nil || maxScreenBitrate < 0 {
maxScreenBitrate = 0
}
hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{
id: id,
url: u,
secret: []byte(secret),
allowHttp: parsed.Scheme == "http",
maxStreamBitrate: maxStreamBitrate,
maxScreenBitrate: maxScreenBitrate,
sessionLimit: uint64(sessionLimit),
})
}
return hosts
}
func (b *BackendConfiguration) Reload(config *goconf.ConfigFile) {
if b.compatBackend != nil {
log.Println("Old-style configuration active, reload is not supported")
return
}
if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
configuredHosts := getConfiguredHosts(backendIds, config)
// remove backends that are no longer configured
for hostname := range b.backends {
if _, ok := configuredHosts[hostname]; !ok {
b.RemoveBackendsForHost(hostname)
}
}
// rewrite backends adding newly configured ones and rewriting existing ones
for hostname, configuredBackends := range configuredHosts {
b.UpsertHost(hostname, configuredBackends)
}
}
}
func (b *BackendConfiguration) GetCompatBackend() *Backend {
return b.compatBackend
}
func (b *BackendConfiguration) GetBackend(u *url.URL) *Backend {
if strings.Contains(u.Host, ":") && hasStandardPort(u) {
u.Host = u.Hostname()
}
entries, found := b.backends[u.Host]
entries, found := s.backends[u.Host]
if !found {
if b.allowAll {
return b.compatBackend
}
return nil
}
s := u.String()
if s[len(s)-1] != '/' {
s += "/"
url := u.String()
if url[len(url)-1] != '/' {
url += "/"
}
for _, entry := range entries {
if !entry.IsUrlAllowed(u) {
@ -379,7 +157,7 @@ func (b *BackendConfiguration) GetBackend(u *url.URL) *Backend {
if entry.url == "" {
// Old-style configuration, only hosts are configured.
return entry
} else if strings.HasPrefix(s, entry.url) {
} else if strings.HasPrefix(url, entry.url) {
return entry
}
}
@ -387,12 +165,59 @@ func (b *BackendConfiguration) GetBackend(u *url.URL) *Backend {
return nil
}
func (b *BackendConfiguration) GetBackends() []*Backend {
var result []*Backend
for _, entries := range b.backends {
result = append(result, entries...)
type BackendConfiguration struct {
storage BackendStorage
}
func NewBackendConfiguration(config *goconf.ConfigFile, etcdClient *EtcdClient) (*BackendConfiguration, error) {
backendType, _ := config.GetString("backend", "backendtype")
if backendType == "" {
backendType = DefaultBackendType
}
return result
RegisterBackendConfigurationStats()
var storage BackendStorage
var err error
switch backendType {
case BackendTypeStatic:
storage, err = NewBackendStorageStatic(config)
case BackendTypeEtcd:
storage, err = NewBackendStorageEtcd(config, etcdClient)
default:
err = fmt.Errorf("unknown backend type: %s", backendType)
}
if err != nil {
return nil, err
}
return &BackendConfiguration{
storage: storage,
}, nil
}
func (b *BackendConfiguration) Close() {
b.storage.Close()
}
func (b *BackendConfiguration) Reload(config *goconf.ConfigFile) {
b.storage.Reload(config)
}
func (b *BackendConfiguration) GetCompatBackend() *Backend {
return b.storage.GetCompatBackend()
}
func (b *BackendConfiguration) GetBackend(u *url.URL) *Backend {
if strings.Contains(u.Host, ":") && hasStandardPort(u) {
u.Host = u.Hostname()
}
return b.storage.GetBackend(u)
}
func (b *BackendConfiguration) GetBackends() []*Backend {
return b.storage.GetBackends()
}
func (b *BackendConfiguration) IsUrlAllowed(u *url.URL) bool {
@ -416,5 +241,5 @@ func (b *BackendConfiguration) GetSecret(u *url.URL) []byte {
return nil
}
return entry.secret
return entry.Secret()
}

View file

@ -23,8 +23,10 @@ package signaling
import (
"bytes"
"context"
"net/url"
"reflect"
"sort"
"testing"
"github.com/dlintw/goconf"
@ -104,7 +106,7 @@ func TestIsUrlAllowed_Compat(t *testing.T) {
config.AddOption("backend", "allowed", "domain.invalid")
config.AddOption("backend", "allowhttp", "true")
config.AddOption("backend", "secret", string(testBackendSecret))
cfg, err := NewBackendConfiguration(config)
cfg, err := NewBackendConfiguration(config, nil)
if err != nil {
t.Fatal(err)
}
@ -125,7 +127,7 @@ func TestIsUrlAllowed_CompatForceHttps(t *testing.T) {
config := goconf.NewConfigFile()
config.AddOption("backend", "allowed", "domain.invalid")
config.AddOption("backend", "secret", string(testBackendSecret))
cfg, err := NewBackendConfiguration(config)
cfg, err := NewBackendConfiguration(config, nil)
if err != nil {
t.Fatal(err)
}
@ -170,7 +172,7 @@ func TestIsUrlAllowed(t *testing.T) {
config.AddOption("baz", "secret", string(testBackendSecret)+"-baz")
config.AddOption("lala", "url", "https://otherdomain.invalid/")
config.AddOption("lala", "secret", string(testBackendSecret)+"-lala")
cfg, err := NewBackendConfiguration(config)
cfg, err := NewBackendConfiguration(config, nil)
if err != nil {
t.Fatal(err)
}
@ -187,7 +189,7 @@ func TestIsUrlAllowed_EmptyAllowlist(t *testing.T) {
config := goconf.NewConfigFile()
config.AddOption("backend", "allowed", "")
config.AddOption("backend", "secret", string(testBackendSecret))
cfg, err := NewBackendConfiguration(config)
cfg, err := NewBackendConfiguration(config, nil)
if err != nil {
t.Fatal(err)
}
@ -207,7 +209,7 @@ func TestIsUrlAllowed_AllowAll(t *testing.T) {
config.AddOption("backend", "allowall", "true")
config.AddOption("backend", "allowed", "")
config.AddOption("backend", "secret", string(testBackendSecret))
cfg, err := NewBackendConfiguration(config)
cfg, err := NewBackendConfiguration(config, nil)
if err != nil {
t.Fatal(err)
}
@ -247,7 +249,7 @@ func TestBackendReloadNoChange(t *testing.T) {
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(original_config)
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
@ -260,7 +262,7 @@ func TestBackendReloadNoChange(t *testing.T) {
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
n_cfg, err := NewBackendConfiguration(new_config)
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
@ -282,7 +284,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(original_config)
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
@ -296,7 +298,7 @@ func TestBackendReloadChangeExistingURL(t *testing.T) {
new_config.AddOption("backend1", "sessionlimit", "10")
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
n_cfg, err := NewBackendConfiguration(new_config)
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
@ -322,7 +324,7 @@ func TestBackendReloadChangeSecret(t *testing.T) {
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(original_config)
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
@ -335,7 +337,7 @@ func TestBackendReloadChangeSecret(t *testing.T) {
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend3")
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
n_cfg, err := NewBackendConfiguration(new_config)
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
@ -358,7 +360,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
original_config.AddOption("backend", "allowall", "false")
original_config.AddOption("backend1", "url", "http://domain1.invalid")
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
o_cfg, err := NewBackendConfiguration(original_config)
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
@ -372,7 +374,7 @@ func TestBackendReloadAddBackend(t *testing.T) {
new_config.AddOption("backend2", "url", "http://domain2.invalid")
new_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
new_config.AddOption("backend2", "sessionlimit", "10")
n_cfg, err := NewBackendConfiguration(new_config)
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
@ -400,7 +402,7 @@ func TestBackendReloadRemoveHost(t *testing.T) {
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
original_config.AddOption("backend2", "url", "http://domain2.invalid")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(original_config)
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
@ -411,7 +413,7 @@ func TestBackendReloadRemoveHost(t *testing.T) {
new_config.AddOption("backend", "allowall", "false")
new_config.AddOption("backend1", "url", "http://domain1.invalid")
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
n_cfg, err := NewBackendConfiguration(new_config)
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
@ -437,7 +439,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
original_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
original_config.AddOption("backend2", "url", "http://domain1.invalid/bar/")
original_config.AddOption("backend2", "secret", string(testBackendSecret)+"-backend2")
o_cfg, err := NewBackendConfiguration(original_config)
o_cfg, err := NewBackendConfiguration(original_config, nil)
if err != nil {
t.Fatal(err)
}
@ -448,7 +450,7 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
new_config.AddOption("backend", "allowall", "false")
new_config.AddOption("backend1", "url", "http://domain1.invalid/foo/")
new_config.AddOption("backend1", "secret", string(testBackendSecret)+"-backend1")
n_cfg, err := NewBackendConfiguration(new_config)
n_cfg, err := NewBackendConfiguration(new_config, nil)
if err != nil {
t.Fatal(err)
}
@ -464,3 +466,155 @@ func TestBackendReloadRemoveBackendFromSharedHost(t *testing.T) {
t.Error("BackendConfiguration should be equal after Reload")
}
}
func sortBackends(backends []*Backend) []*Backend {
result := make([]*Backend, len(backends))
copy(result, backends)
sort.Slice(result, func(i, j int) bool {
return result[i].Id() < result[j].Id()
})
return result
}
func mustParse(s string) *url.URL {
p, err := url.Parse(s)
if err != nil {
panic(err)
}
return p
}
func TestBackendConfiguration_Etcd(t *testing.T) {
etcd, client := NewEtcdClientForTest(t)
url1 := "https://domain1.invalid/foo"
initialSecret1 := string(testBackendSecret) + "-backend1-initial"
secret1 := string(testBackendSecret) + "-backend1"
SetEtcdValue(etcd, "/backends/1_one", []byte("{\"url\":\""+url1+"\",\"secret\":\""+initialSecret1+"\"}"))
config := goconf.NewConfigFile()
config.AddOption("backend", "backendtype", "etcd")
config.AddOption("backend", "backendprefix", "/backends")
cfg, err := NewBackendConfiguration(config, client)
if err != nil {
t.Fatal(err)
}
defer cfg.Close()
storage := cfg.storage.(*backendStorageEtcd)
ch := make(chan bool, 1)
storage.SetWakeupForTesting(ch)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
if err := storage.WaitForInitialized(ctx); err != nil {
t.Fatal(err)
}
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
t.Errorf("Expected one backend, got %+v", backends)
} else if backends[0].url != url1 {
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
} else if string(backends[0].secret) != initialSecret1 {
t.Errorf("Expected backend secret %s, got %s", initialSecret1, string(backends[0].secret))
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
}
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/1_one", []byte("{\"url\":\""+url1+"\",\"secret\":\""+secret1+"\"}"))
<-ch
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
t.Errorf("Expected one backend, got %+v", backends)
} else if backends[0].url != url1 {
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
} else if string(backends[0].secret) != secret1 {
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
}
url2 := "https://domain1.invalid/bar"
secret2 := string(testBackendSecret) + "-backend2"
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/2_two", []byte("{\"url\":\""+url2+"\",\"secret\":\""+secret2+"\"}"))
<-ch
if backends := sortBackends(cfg.GetBackends()); len(backends) != 2 {
t.Errorf("Expected two backends, got %+v", backends)
} else if backends[0].url != url1 {
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
} else if string(backends[0].secret) != secret1 {
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
} else if backends[1].url != url2 {
t.Errorf("Expected backend url %s, got %s", url2, backends[1].url)
} else if string(backends[1].secret) != secret2 {
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[1].secret))
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
t.Errorf("Expected backend %+v, got %+v", backends[1], backend)
}
url3 := "https://domain2.invalid/foo"
secret3 := string(testBackendSecret) + "-backend3"
drainWakeupChannel(ch)
SetEtcdValue(etcd, "/backends/3_three", []byte("{\"url\":\""+url3+"\",\"secret\":\""+secret3+"\"}"))
<-ch
if backends := sortBackends(cfg.GetBackends()); len(backends) != 3 {
t.Errorf("Expected three backends, got %+v", backends)
} else if backends[0].url != url1 {
t.Errorf("Expected backend url %s, got %s", url1, backends[0].url)
} else if string(backends[0].secret) != secret1 {
t.Errorf("Expected backend secret %s, got %s", secret1, string(backends[0].secret))
} else if backends[1].url != url2 {
t.Errorf("Expected backend url %s, got %s", url2, backends[1].url)
} else if string(backends[1].secret) != secret2 {
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[1].secret))
} else if backends[2].url != url3 {
t.Errorf("Expected backend url %s, got %s", url3, backends[2].url)
} else if string(backends[2].secret) != secret3 {
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[2].secret))
} else if backend := cfg.GetBackend(mustParse(url1)); backend != backends[0] {
t.Errorf("Expected backend %+v, got %+v", backends[0], backend)
} else if backend := cfg.GetBackend(mustParse(url2)); backend != backends[1] {
t.Errorf("Expected backend %+v, got %+v", backends[1], backend)
} else if backend := cfg.GetBackend(mustParse(url3)); backend != backends[2] {
t.Errorf("Expected backend %+v, got %+v", backends[2], backend)
}
drainWakeupChannel(ch)
DeleteEtcdValue(etcd, "/backends/1_one")
<-ch
if backends := sortBackends(cfg.GetBackends()); len(backends) != 2 {
t.Errorf("Expected two backends, got %+v", backends)
} else if backends[0].url != url2 {
t.Errorf("Expected backend url %s, got %s", url2, backends[0].url)
} else if string(backends[0].secret) != secret2 {
t.Errorf("Expected backend secret %s, got %s", secret2, string(backends[0].secret))
} else if backends[1].url != url3 {
t.Errorf("Expected backend url %s, got %s", url3, backends[1].url)
} else if string(backends[1].secret) != secret3 {
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[1].secret))
}
drainWakeupChannel(ch)
DeleteEtcdValue(etcd, "/backends/2_two")
<-ch
if backends := sortBackends(cfg.GetBackends()); len(backends) != 1 {
t.Errorf("Expected one backend, got %+v", backends)
} else if backends[0].url != url3 {
t.Errorf("Expected backend url %s, got %s", url3, backends[0].url)
} else if string(backends[0].secret) != secret3 {
t.Errorf("Expected backend secret %s, got %s", secret3, string(backends[0].secret))
}
if _, found := storage.backends["domain1.invalid"]; found {
t.Errorf("Should have removed host information for %s", "domain1.invalid")
}
}

View file

@ -88,7 +88,7 @@ func CreateBackendServerForTestFromConfig(t *testing.T, config *goconf.ConfigFil
config.AddOption("clients", "internalsecret", string(testInternalSecret))
config.AddOption("geoip", "url", "none")
events := getAsyncEventsForTest(t)
hub, err := NewHub(config, events, nil, nil, r, "no-version")
hub, err := NewHub(config, events, nil, nil, nil, r, "no-version")
if err != nil {
t.Fatal(err)
}
@ -162,7 +162,7 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g
events1.Close()
})
client1 := NewGrpcClientsForTest(t, addr2)
hub1, err := NewHub(config1, events1, grpcServer1, client1, r1, "no-version")
hub1, err := NewHub(config1, events1, grpcServer1, client1, nil, r1, "no-version")
if err != nil {
t.Fatal(err)
}
@ -191,7 +191,7 @@ func CreateBackendServerWithClusteringForTestFromConfig(t *testing.T, config1 *g
events2.Close()
})
client2 := NewGrpcClientsForTest(t, addr1)
hub2, err := NewHub(config2, events2, grpcServer2, client2, r2, "no-version")
hub2, err := NewHub(config2, events2, grpcServer2, client2, nil, r2, "no-version")
if err != nil {
t.Fatal(err)
}

256
backend_storage_etcd.go Normal file
View file

@ -0,0 +1,256 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"encoding/json"
"fmt"
"log"
"net/url"
"time"
"github.com/dlintw/goconf"
clientv3 "go.etcd.io/etcd/client/v3"
)
type backendStorageEtcd struct {
backendStorageCommon
etcdClient *EtcdClient
keyPrefix string
keyInfos map[string]*BackendInformationEtcd
initializedCtx context.Context
initializedFunc context.CancelFunc
wakeupChanForTesting chan bool
}
func NewBackendStorageEtcd(config *goconf.ConfigFile, etcdClient *EtcdClient) (BackendStorage, error) {
if etcdClient == nil || !etcdClient.IsConfigured() {
return nil, fmt.Errorf("no etcd endpoints configured")
}
keyPrefix, _ := config.GetString("backend", "backendprefix")
if keyPrefix == "" {
return nil, fmt.Errorf("no backend prefix configured")
}
initializedCtx, initializedFunc := context.WithCancel(context.Background())
result := &backendStorageEtcd{
backendStorageCommon: backendStorageCommon{
backends: make(map[string][]*Backend),
},
etcdClient: etcdClient,
keyPrefix: keyPrefix,
keyInfos: make(map[string]*BackendInformationEtcd),
initializedCtx: initializedCtx,
initializedFunc: initializedFunc,
}
etcdClient.AddListener(result)
return result, nil
}
func (s *backendStorageEtcd) WaitForInitialized(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-s.initializedCtx.Done():
return nil
}
}
func (s *backendStorageEtcd) SetWakeupForTesting(ch chan bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.wakeupChanForTesting = ch
}
func (s *backendStorageEtcd) wakeupForTesting() {
if s.wakeupChanForTesting == nil {
return
}
select {
case s.wakeupChanForTesting <- true:
default:
}
}
func (s *backendStorageEtcd) EtcdClientCreated(client *EtcdClient) {
go func() {
if err := client.Watch(context.Background(), s.keyPrefix, s, clientv3.WithPrefix()); err != nil {
log.Printf("Error processing watch for %s: %s", s.keyPrefix, err)
}
}()
go func() {
client.WaitForConnection()
waitDelay := initialWaitDelay
for {
response, err := s.getBackends(client, s.keyPrefix)
if err != nil {
if err == context.DeadlineExceeded {
log.Printf("Timeout getting initial list of backends, retry in %s", waitDelay)
} else {
log.Printf("Could not get initial list of backends, retry in %s: %s", waitDelay, err)
}
time.Sleep(waitDelay)
waitDelay = waitDelay * 2
if waitDelay > maxWaitDelay {
waitDelay = maxWaitDelay
}
continue
}
for _, ev := range response.Kvs {
s.EtcdKeyUpdated(client, string(ev.Key), ev.Value)
}
s.initializedFunc()
return
}
}()
}
func (s *backendStorageEtcd) getBackends(client *EtcdClient, keyPrefix string) (*clientv3.GetResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
return client.Get(ctx, keyPrefix, clientv3.WithPrefix())
}
func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data []byte) {
var info BackendInformationEtcd
if err := json.Unmarshal(data, &info); err != nil {
log.Printf("Could not decode backend information %s: %s", string(data), err)
return
}
if err := info.CheckValid(); err != nil {
log.Printf("Received invalid backend information %s: %s", string(data), err)
return
}
backend := &Backend{
id: key,
url: info.Url,
secret: []byte(info.Secret),
allowHttp: info.parsedUrl.Scheme == "http",
maxStreamBitrate: info.MaxStreamBitrate,
maxScreenBitrate: info.MaxScreenBitrate,
sessionLimit: info.SessionLimit,
}
host := info.parsedUrl.Host
s.mu.Lock()
defer s.mu.Unlock()
s.keyInfos[key] = &info
entries, found := s.backends[host]
if !found {
// Simple case, first backend for this host
log.Printf("Added backend %s (from %s)", info.Url, key)
s.backends[host] = []*Backend{backend}
statsBackendsCurrent.Inc()
s.wakeupForTesting()
return
}
// Was the backend changed?
replaced := false
for idx, entry := range entries {
if entry.id == key {
log.Printf("Updated backend %s (from %s)", info.Url, key)
entries[idx] = backend
replaced = true
break
}
}
if !replaced {
// New backend, add to list.
log.Printf("Added backend %s (from %s)", info.Url, key)
s.backends[host] = append(entries, backend)
statsBackendsCurrent.Inc()
}
s.wakeupForTesting()
}
func (s *backendStorageEtcd) EtcdKeyDeleted(client *EtcdClient, key string) {
s.mu.Lock()
defer s.mu.Unlock()
info, found := s.keyInfos[key]
if !found {
return
}
delete(s.keyInfos, key)
host := info.parsedUrl.Host
entries, found := s.backends[host]
if !found {
return
}
log.Printf("Removing backend %s (from %s)", info.Url, key)
newEntries := make([]*Backend, 0, len(entries)-1)
for _, entry := range entries {
if entry.id == key {
statsBackendsCurrent.Dec()
continue
}
newEntries = append(newEntries, entry)
}
if len(newEntries) > 0 {
s.backends[host] = newEntries
} else {
delete(s.backends, host)
}
s.wakeupForTesting()
}
func (s *backendStorageEtcd) Close() {
s.etcdClient.RemoveListener(s)
}
func (s *backendStorageEtcd) Reload(config *goconf.ConfigFile) {
// Backend updates are processed through etcd.
}
func (s *backendStorageEtcd) GetCompatBackend() *Backend {
return nil
}
func (s *backendStorageEtcd) GetBackend(u *url.URL) *Backend {
s.mu.RLock()
defer s.mu.RUnlock()
return s.getBackendLocked(u)
}

303
backend_storage_static.go Normal file
View file

@ -0,0 +1,303 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2022 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"log"
"net/url"
"reflect"
"strings"
"github.com/dlintw/goconf"
)
type backendStorageStatic struct {
backendStorageCommon
// Deprecated
allowAll bool
commonSecret []byte
compatBackend *Backend
}
func NewBackendStorageStatic(config *goconf.ConfigFile) (BackendStorage, error) {
allowAll, _ := config.GetBool("backend", "allowall")
allowHttp, _ := config.GetBool("backend", "allowhttp")
commonSecret, _ := config.GetString("backend", "secret")
sessionLimit, err := config.GetInt("backend", "sessionlimit")
if err != nil || sessionLimit < 0 {
sessionLimit = 0
}
backends := make(map[string][]*Backend)
var compatBackend *Backend
numBackends := 0
if allowAll {
log.Println("WARNING: All backend hostnames are allowed, only use for development!")
compatBackend = &Backend{
id: "compat",
secret: []byte(commonSecret),
compat: true,
allowHttp: allowHttp,
sessionLimit: uint64(sessionLimit),
}
if sessionLimit > 0 {
log.Printf("Allow a maximum of %d sessions", sessionLimit)
}
numBackends++
} else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" {
for host, configuredBackends := range getConfiguredHosts(backendIds, config) {
backends[host] = append(backends[host], configuredBackends...)
for _, be := range configuredBackends {
log.Printf("Backend %s added for %s", be.id, be.url)
}
numBackends += len(configuredBackends)
}
} else if allowedUrls, _ := config.GetString("backend", "allowed"); allowedUrls != "" {
// Old-style configuration, only hosts are configured and are using a common secret.
allowMap := make(map[string]bool)
for _, u := range strings.Split(allowedUrls, ",") {
u = strings.TrimSpace(u)
if idx := strings.IndexByte(u, '/'); idx != -1 {
log.Printf("WARNING: Removing path from allowed hostname \"%s\", check your configuration!", u)
u = u[:idx]
}
if u != "" {
allowMap[strings.ToLower(u)] = true
}
}
if len(allowMap) == 0 {
log.Println("WARNING: No backend hostnames are allowed, check your configuration!")
} else {
compatBackend = &Backend{
id: "compat",
secret: []byte(commonSecret),
compat: true,
allowHttp: allowHttp,
sessionLimit: uint64(sessionLimit),
}
hosts := make([]string, 0, len(allowMap))
for host := range allowMap {
hosts = append(hosts, host)
backends[host] = []*Backend{compatBackend}
}
if len(hosts) > 1 {
log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.")
}
log.Printf("Allowed backend hostnames: %s", hosts)
if sessionLimit > 0 {
log.Printf("Allow a maximum of %d sessions", sessionLimit)
}
numBackends++
}
}
statsBackendsCurrent.Add(float64(numBackends))
return &backendStorageStatic{
backendStorageCommon: backendStorageCommon{
backends: backends,
},
allowAll: allowAll,
commonSecret: []byte(commonSecret),
compatBackend: compatBackend,
}, nil
}
func (s *backendStorageStatic) Close() {
}
func (s *backendStorageStatic) RemoveBackendsForHost(host string) {
if oldBackends := s.backends[host]; len(oldBackends) > 0 {
for _, backend := range oldBackends {
log.Printf("Backend %s removed for %s", backend.id, backend.url)
}
statsBackendsCurrent.Sub(float64(len(oldBackends)))
}
delete(s.backends, host)
}
func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend) {
for existingIndex, existingBackend := range s.backends[host] {
found := false
index := 0
for _, newBackend := range backends {
if reflect.DeepEqual(existingBackend, newBackend) { // otherwise we could manually compare the struct members here
found = true
backends = append(backends[:index], backends[index+1:]...)
break
} else if newBackend.id == existingBackend.id {
found = true
s.backends[host][existingIndex] = newBackend
backends = append(backends[:index], backends[index+1:]...)
log.Printf("Backend %s updated for %s", newBackend.id, newBackend.url)
break
}
index++
}
if !found {
removed := s.backends[host][existingIndex]
log.Printf("Backend %s removed for %s", removed.id, removed.url)
s.backends[host] = append(s.backends[host][:existingIndex], s.backends[host][existingIndex+1:]...)
statsBackendsCurrent.Dec()
}
}
s.backends[host] = append(s.backends[host], backends...)
for _, added := range backends {
log.Printf("Backend %s added for %s", added.id, added.url)
}
statsBackendsCurrent.Add(float64(len(backends)))
}
func getConfiguredBackendIDs(backendIds string) (ids []string) {
seen := make(map[string]bool)
for _, id := range strings.Split(backendIds, ",") {
id = strings.TrimSpace(id)
if id == "" {
continue
}
if seen[id] {
continue
}
ids = append(ids, id)
seen[id] = true
}
return ids
}
func getConfiguredHosts(backendIds string, config *goconf.ConfigFile) (hosts map[string][]*Backend) {
hosts = make(map[string][]*Backend)
for _, id := range getConfiguredBackendIDs(backendIds) {
u, _ := config.GetString(id, "url")
if u == "" {
log.Printf("Backend %s is missing or incomplete, skipping", id)
continue
}
if u[len(u)-1] != '/' {
u += "/"
}
parsed, err := url.Parse(u)
if err != nil {
log.Printf("Backend %s has an invalid url %s configured (%s), skipping", id, u, err)
continue
}
if strings.Contains(parsed.Host, ":") && hasStandardPort(parsed) {
parsed.Host = parsed.Hostname()
u = parsed.String()
}
secret, _ := config.GetString(id, "secret")
if u == "" || secret == "" {
log.Printf("Backend %s is missing or incomplete, skipping", id)
continue
}
sessionLimit, err := config.GetInt(id, "sessionlimit")
if err != nil || sessionLimit < 0 {
sessionLimit = 0
}
if sessionLimit > 0 {
log.Printf("Backend %s allows a maximum of %d sessions", id, sessionLimit)
}
maxStreamBitrate, err := config.GetInt(id, "maxstreambitrate")
if err != nil || maxStreamBitrate < 0 {
maxStreamBitrate = 0
}
maxScreenBitrate, err := config.GetInt(id, "maxscreenbitrate")
if err != nil || maxScreenBitrate < 0 {
maxScreenBitrate = 0
}
hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{
id: id,
url: u,
secret: []byte(secret),
allowHttp: parsed.Scheme == "http",
maxStreamBitrate: maxStreamBitrate,
maxScreenBitrate: maxScreenBitrate,
sessionLimit: uint64(sessionLimit),
})