diff --git a/src/signaling/backend_client.go b/src/signaling/backend_client.go index 0fd2fcf..993394a 100644 --- a/src/signaling/backend_client.go +++ b/src/signaling/backend_client.go @@ -107,12 +107,8 @@ func (b *BackendClient) getPool(url *url.URL) (*HttpClientPool, error) { return pool, nil } -func (b *BackendClient) IsCompatBackend() bool { - return b.backends.IsCompatBackend() -} - -func (b *BackendClient) GetCommonSecret() []byte { - return b.backends.GetCommonSecret() +func (b *BackendClient) GetCompatBackend() *Backend { + return b.backends.GetCompatBackend() } func (b *BackendClient) GetBackend(u *url.URL) *Backend { diff --git a/src/signaling/backend_configuration.go b/src/signaling/backend_configuration.go index 3c9cab6..ffc547c 100644 --- a/src/signaling/backend_configuration.go +++ b/src/signaling/backend_configuration.go @@ -52,21 +52,25 @@ type BackendConfiguration struct { backends map[string][]*Backend // Deprecated - whitelistAll bool + allowAll bool commonSecret []byte - compatBackend bool + compatBackend *Backend } func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, error) { - whitelistAll, _ := config.GetBool("backend", "allowall") + allowAll, _ := config.GetBool("backend", "allowall") commonSecret, _ := config.GetString("backend", "secret") backends := make(map[string][]*Backend) - compatBackend := commonSecret != "" - if whitelistAll { + var compatBackend *Backend + if allowAll { log.Println("WARNING: All backend hostnames are allowed, only use for development!") + compatBackend = &Backend{ + id: "compat", + secret: []byte(commonSecret), + compat: true, + } } else if backendIds, _ := config.GetString("backend", "backends"); backendIds != "" { seenIds := make(map[string]bool) - compatBackend = false for _, id := range strings.Split(backendIds, ",") { id = strings.TrimSpace(id) if id == "" { @@ -103,7 +107,7 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, } } else if allowedUrls, _ := config.GetString("backend", "allowed"); allowedUrls != "" { // Old-style configuration, only hosts are configured and are using a common secret. - whitelist := make(map[string]bool) + allowMap := make(map[string]bool) for _, u := range strings.Split(allowedUrls, ",") { u = strings.TrimSpace(u) if idx := strings.IndexByte(u, '/'); idx != -1 { @@ -111,23 +115,22 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, u = u[:idx] } if u != "" { - whitelist[strings.ToLower(u)] = true + allowMap[strings.ToLower(u)] = true } } - if len(whitelist) == 0 { + if len(allowMap) == 0 { log.Println("WARNING: No backend hostnames are allowed, check your configuration!") } else { - hosts := make([]string, 0, len(whitelist)) - for host := range whitelist { + compatBackend = &Backend{ + id: "compat", + secret: []byte(commonSecret), + compat: true, + } + hosts := make([]string, 0, len(allowMap)) + for host := range allowMap { hosts = append(hosts, host) - backends[host] = []*Backend{ - &Backend{ - id: "compat", - secret: []byte(commonSecret), - compat: true, - }, - } + backends[host] = []*Backend{compatBackend} } if len(hosts) > 1 { log.Println("WARNING: Using deprecated backend configuration. Please migrate the \"allowed\" setting to the new \"backends\" configuration.") @@ -139,23 +142,22 @@ func NewBackendConfiguration(config *goconf.ConfigFile) (*BackendConfiguration, return &BackendConfiguration{ backends: backends, - whitelistAll: whitelistAll, + allowAll: allowAll, commonSecret: []byte(commonSecret), compatBackend: compatBackend, }, nil } -func (b *BackendConfiguration) IsCompatBackend() bool { +func (b *BackendConfiguration) GetCompatBackend() *Backend { return b.compatBackend } -func (b *BackendConfiguration) GetCommonSecret() []byte { - return b.commonSecret -} - func (b *BackendConfiguration) GetBackend(u *url.URL) *Backend { entries, found := b.backends[u.Host] if !found { + if b.allowAll { + return b.compatBackend + } return nil } @@ -191,10 +193,6 @@ func (b *BackendConfiguration) IsUrlAllowed(u *url.URL) bool { return false } - if b.whitelistAll { - return true - } - backend := b.GetBackend(u) return backend != nil } @@ -205,10 +203,6 @@ func (b *BackendConfiguration) GetSecret(u *url.URL) []byte { return nil } - if b.whitelistAll { - return b.commonSecret - } - entry := b.GetBackend(u) if entry == nil { return nil diff --git a/src/signaling/backend_server.go b/src/signaling/backend_server.go index 2dbff86..d84a393 100644 --- a/src/signaling/backend_server.go +++ b/src/signaling/backend_server.go @@ -520,11 +520,10 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body } } - var secret []byte if backend == nil { - if b.hub.backend.IsCompatBackend() { + if compatBackend := b.hub.backend.GetCompatBackend(); compatBackend != nil { // Old-style configuration using a single secret for all backends. - secret = b.hub.backend.GetCommonSecret() + backend = compatBackend } else { // Old-style Talk, find backend that created the checksum. // TODO(fancycode): Remove once all supported Talk versions send the backend header. @@ -534,19 +533,15 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body break } } - - if backend == nil { - http.Error(w, "Authentication check failed", http.StatusForbidden) - return - } - - secret = backend.Secret() } - } else { - secret = backend.Secret() + + if backend == nil { + http.Error(w, "Authentication check failed", http.StatusForbidden) + return + } } - if !ValidateBackendChecksum(r, body, secret) { + if !ValidateBackendChecksum(r, body, backend.Secret()) { http.Error(w, "Authentication check failed", http.StatusForbidden) return } diff --git a/src/signaling/hub_test.go b/src/signaling/hub_test.go index c4c3dc0..4e17cbc 100644 --- a/src/signaling/hub_test.go +++ b/src/signaling/hub_test.go @@ -356,6 +356,41 @@ func TestClientHelloWithSpaces(t *testing.T) { } } +func TestClientHelloAllowAll(t *testing.T) { + hub, _, _, server, shutdown := CreateHubForTestWithConfig(t, func(server *httptest.Server) (*goconf.ConfigFile, error) { + config, err := getTestConfig(server) + if err != nil { + return nil, err + } + + config.RemoveOption("backend", "allowed") + config.AddOption("backend", "allowall", "true") + return config, nil + }) + defer shutdown() + + client := NewTestClient(t, server, hub) + defer client.CloseWithBye() + + if err := client.SendHello(testDefaultUserId); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + if hello, err := client.RunUntilHello(ctx); err != nil { + t.Error(err) + } else { + if hello.Hello.UserId != testDefaultUserId { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId, hello.Hello) + } + if hello.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello.Hello) + } + } +} + func TestSessionIdsUnordered(t *testing.T) { hub, _, _, server, shutdown := CreateHubForTest(t) defer shutdown()