mirror of
https://github.com/strukturag/nextcloud-spreed-signaling
synced 2024-05-17 04:56:33 +02:00
Update capabilities if no hello v2 token key is found in cache.
This is necessary to detect updated Talk setups where the signaling server might have cached capabilities without the v2 token key but the clients are trying to connect with a hello v2 token. Fetch updated capabilities in such cases (but throttle to about one invalidation per minute).
This commit is contained in:
parent
184c941f8a
commit
cbb6d9ca53
121
capabilities.go
121
capabilities.go
|
@ -43,8 +43,14 @@ const (
|
||||||
|
|
||||||
// Cache received capabilities for one hour.
|
// Cache received capabilities for one hour.
|
||||||
CapabilitiesCacheDuration = time.Hour
|
CapabilitiesCacheDuration = time.Hour
|
||||||
|
|
||||||
|
// Don't invalidate more than once per minute.
|
||||||
|
maxInvalidateInterval = time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Can be overwritten by tests.
|
||||||
|
var getCapabilitiesNow = time.Now
|
||||||
|
|
||||||
type capabilitiesEntry struct {
|
type capabilitiesEntry struct {
|
||||||
nextUpdate time.Time
|
nextUpdate time.Time
|
||||||
capabilities map[string]interface{}
|
capabilities map[string]interface{}
|
||||||
|
@ -53,16 +59,18 @@ type capabilitiesEntry struct {
|
||||||
type Capabilities struct {
|
type Capabilities struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
|
||||||
version string
|
version string
|
||||||
pool *HttpClientPool
|
pool *HttpClientPool
|
||||||
entries map[string]*capabilitiesEntry
|
entries map[string]*capabilitiesEntry
|
||||||
|
nextInvalidate map[string]time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) {
|
func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) {
|
||||||
result := &Capabilities{
|
result := &Capabilities{
|
||||||
version: version,
|
version: version,
|
||||||
pool: pool,
|
pool: pool,
|
||||||
entries: make(map[string]*capabilitiesEntry),
|
entries: make(map[string]*capabilitiesEntry),
|
||||||
|
nextInvalidate: make(map[string]time.Time),
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
|
@ -86,7 +94,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
defer c.mu.RUnlock()
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
now := time.Now()
|
now := getCapabilitiesNow()
|
||||||
if entry, found := c.entries[key]; found && entry.nextUpdate.After(now) {
|
if entry, found := c.entries[key]; found && entry.nextUpdate.After(now) {
|
||||||
return entry.capabilities, true
|
return entry.capabilities, true
|
||||||
}
|
}
|
||||||
|
@ -95,7 +103,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) setCapabilities(key string, capabilities map[string]interface{}) {
|
func (c *Capabilities) setCapabilities(key string, capabilities map[string]interface{}) {
|
||||||
now := time.Now()
|
now := getCapabilitiesNow()
|
||||||
entry := &capabilitiesEntry{
|
entry := &capabilitiesEntry{
|
||||||
nextUpdate: now.Add(CapabilitiesCacheDuration),
|
nextUpdate: now.Add(CapabilitiesCacheDuration),
|
||||||
capabilities: capabilities,
|
capabilities: capabilities,
|
||||||
|
@ -106,11 +114,28 @@ func (c *Capabilities) setCapabilities(key string, capabilities map[string]inter
|
||||||
c.entries[key] = entry
|
c.entries[key] = entry
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, error) {
|
func (c *Capabilities) invalidateCapabilities(key string) {
|
||||||
key := u.String()
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
now := getCapabilitiesNow()
|
||||||
|
if entry, found := c.nextInvalidate[key]; found && entry.After(now) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(c.entries, key)
|
||||||
|
c.nextInvalidate[key] = now.Add(maxInvalidateInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Capabilities) getKeyForUrl(u *url.URL) string {
|
||||||
|
key := u.String()
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, bool, error) {
|
||||||
|
key := c.getKeyForUrl(u)
|
||||||
if caps, found := c.getCapabilities(key); found {
|
if caps, found := c.getCapabilities(key); found {
|
||||||
return caps, nil
|
return caps, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
capUrl := *u
|
capUrl := *u
|
||||||
|
@ -128,14 +153,14 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st
|
||||||
client, pool, err := c.pool.Get(ctx, &capUrl)
|
client, pool, err := c.pool.Get(ctx, &capUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
|
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
defer pool.Put(client)
|
defer pool.Put(client)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not create request to %s: %s", &capUrl, err)
|
log.Printf("Could not create request to %s: %s", &capUrl, err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
req.Header.Set("OCS-APIRequest", "true")
|
req.Header.Set("OCS-APIRequest", "true")
|
||||||
|
@ -143,56 +168,56 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
ct := resp.Header.Get("Content-Type")
|
ct := resp.Header.Get("Content-Type")
|
||||||
if !strings.HasPrefix(ct, "application/json") {
|
if !strings.HasPrefix(ct, "application/json") {
|
||||||
log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status)
|
log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status)
|
||||||
return nil, ErrUnsupportedContentType
|
return nil, false, ErrUnsupportedContentType
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not read response body from %s: %s", capUrl.String(), err)
|
log.Printf("Could not read response body from %s: %s", capUrl.String(), err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var ocs OcsResponse
|
var ocs OcsResponse
|
||||||
if err := json.Unmarshal(body, &ocs); err != nil {
|
if err := json.Unmarshal(body, &ocs); err != nil {
|
||||||
log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err)
|
log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
} else if ocs.Ocs == nil || ocs.Ocs.Data == nil {
|
} else if ocs.Ocs == nil || ocs.Ocs.Data == nil {
|
||||||
log.Printf("Incomplete OCS response %s from %s", string(body), u)
|
log.Printf("Incomplete OCS response %s from %s", string(body), u)
|
||||||
return nil, fmt.Errorf("incomplete OCS response")
|
return nil, false, fmt.Errorf("incomplete OCS response")
|
||||||
}
|
}
|
||||||
|
|
||||||
var response CapabilitiesResponse
|
var response CapabilitiesResponse
|
||||||
if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil {
|
if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil {
|
||||||
log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err)
|
log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err)
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
capaObj, found := response.Capabilities[AppNameSpreed]
|
capaObj, found := response.Capabilities[AppNameSpreed]
|
||||||
if !found || capaObj == nil {
|
if !found || capaObj == nil {
|
||||||
log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response)
|
log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response)
|
||||||
return nil, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var capa map[string]interface{}
|
var capa map[string]interface{}
|
||||||
if err := json.Unmarshal(*capaObj, &capa); err != nil {
|
if err := json.Unmarshal(*capaObj, &capa); err != nil {
|
||||||
log.Printf("Unsupported capabilities received for app spreed from %s: %+v", capUrl.String(), response)
|
log.Printf("Unsupported capabilities received for app spreed from %s: %+v", capUrl.String(), response)
|
||||||
return nil, nil
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Received capabilities %+v from %s", capa, capUrl.String())
|
log.Printf("Received capabilities %+v from %s", capa, capUrl.String())
|
||||||
c.setCapabilities(key, capa)
|
c.setCapabilities(key, capa)
|
||||||
return capa, nil
|
return capa, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool {
|
func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool {
|
||||||
caps, err := c.loadCapabilities(ctx, u)
|
caps, _, err := c.loadCapabilities(ctx, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not get capabilities for %s: %s", u, err)
|
log.Printf("Could not get capabilities for %s: %s", u, err)
|
||||||
return false
|
return false
|
||||||
|
@ -217,80 +242,86 @@ func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, fea
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool) {
|
func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool, bool) {
|
||||||
caps, err := c.loadCapabilities(ctx, u)
|
caps, cached, err := c.loadCapabilities(ctx, u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not get capabilities for %s: %s", u, err)
|
log.Printf("Could not get capabilities for %s: %s", u, err)
|
||||||
return nil, false
|
return nil, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
configInterface := caps["config"]
|
configInterface := caps["config"]
|
||||||
if configInterface == nil {
|
if configInterface == nil {
|
||||||
return nil, false
|
return nil, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
config, ok := configInterface.(map[string]interface{})
|
config, ok := configInterface.(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("Invalid config mapping received from %s: %+v", u, configInterface)
|
log.Printf("Invalid config mapping received from %s: %+v", u, configInterface)
|
||||||
return nil, false
|
return nil, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
groupInterface := config[group]
|
groupInterface := config[group]
|
||||||
if groupInterface == nil {
|
if groupInterface == nil {
|
||||||
return nil, false
|
return nil, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
groupConfig, ok := groupInterface.(map[string]interface{})
|
groupConfig, ok := groupInterface.(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("Invalid group mapping \"%s\" received from %s: %+v", group, u, groupInterface)
|
log.Printf("Invalid group mapping \"%s\" received from %s: %+v", group, u, groupInterface)
|
||||||
return nil, false
|
return nil, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return groupConfig, true
|
return groupConfig, cached, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool) {
|
func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool, bool) {
|
||||||
groupConfig, found := c.getConfigGroup(ctx, u, group)
|
groupConfig, cached, found := c.getConfigGroup(ctx, u, group)
|
||||||
if !found {
|
if !found {
|
||||||
return 0, false
|
return 0, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
value, found := groupConfig[key]
|
value, found := groupConfig[key]
|
||||||
if !found {
|
if !found {
|
||||||
return 0, false
|
return 0, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
switch value := value.(type) {
|
switch value := value.(type) {
|
||||||
case int:
|
case int:
|
||||||
return value, true
|
return value, cached, true
|
||||||
case float32:
|
case float32:
|
||||||
return int(value), true
|
return int(value), cached, true
|
||||||
case float64:
|
case float64:
|
||||||
return int(value), true
|
return int(value), cached, true
|
||||||
default:
|
default:
|
||||||
log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value)
|
log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, false
|
return 0, cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool) {
|
func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool, bool) {
|
||||||
groupConfig, found := c.getConfigGroup(ctx, u, group)
|
groupConfig, cached, found := c.getConfigGroup(ctx, u, group)
|
||||||
if !found {
|
if !found {
|
||||||
return "", false
|
return "", cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
value, found := groupConfig[key]
|
value, found := groupConfig[key]
|
||||||
if !found {
|
if !found {
|
||||||
return "", false
|
return "", cached, false
|
||||||
}
|
}
|
||||||
|
|
||||||
switch value := value.(type) {
|
switch value := value.(type) {
|
||||||
case string:
|
case string:
|
||||||
return value, true
|
return value, cached, true
|
||||||
default:
|
default:
|
||||||
log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value)
|
log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", false
|
return "", cached, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Capabilities) InvalidateCapabilities(u *url.URL) {
|
||||||
|
key := c.getKeyForUrl(u)
|
||||||
|
|
||||||
|
c.invalidateCapabilities(key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,12 +28,14 @@ import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
|
func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse)) (*url.URL, *Capabilities) {
|
||||||
pool, err := NewHttpClientPool(1, false)
|
pool, err := NewHttpClientPool(1, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -84,6 +86,10 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if callback != nil {
|
||||||
|
callback(response)
|
||||||
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(response)
|
data, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Could not marshal %+v: %s", response, err)
|
t.Errorf("Could not marshal %+v: %s", response, err)
|
||||||
|
@ -110,6 +116,19 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
|
||||||
return u, capabilities
|
return u, capabilities
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
|
||||||
|
return NewCapabilitiesForTestWithCallback(t, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetCapabilitiesGetNow(t *testing.T, f func() time.Time) {
|
||||||
|
old := getCapabilitiesNow
|
||||||
|
t.Cleanup(func() {
|
||||||
|
getCapabilitiesNow = old
|
||||||
|
})
|
||||||
|
|
||||||
|
getCapabilitiesNow = f
|
||||||
|
}
|
||||||
|
|
||||||
func TestCapabilities(t *testing.T) {
|
func TestCapabilities(t *testing.T) {
|
||||||
url, capabilities := NewCapabilitiesForTest(t)
|
url, capabilities := NewCapabilitiesForTest(t)
|
||||||
|
|
||||||
|
@ -124,34 +143,122 @@ func TestCapabilities(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedString := "bar"
|
expectedString := "bar"
|
||||||
if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||||
t.Error("could not find value for \"foo\"")
|
t.Error("could not find value for \"foo\"")
|
||||||
} else if value != expectedString {
|
} else if value != expectedString {
|
||||||
t.Errorf("expected value %s, got %s", expectedString, value)
|
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found {
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found {
|
||||||
t.Errorf("should not have found value for \"baz\", got %s", value)
|
t.Errorf("should not have found value for \"baz\", got %s", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found {
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found {
|
||||||
t.Errorf("should not have found value for \"invalid\", got %s", value)
|
t.Errorf("should not have found value for \"invalid\", got %s", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found {
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found {
|
||||||
t.Errorf("should not have found value for \"baz\", got %s", value)
|
t.Errorf("should not have found value for \"baz\", got %s", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedInt := 42
|
expectedInt := 42
|
||||||
if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found {
|
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found {
|
||||||
t.Error("could not find value for \"baz\"")
|
t.Error("could not find value for \"baz\"")
|
||||||
} else if value != expectedInt {
|
} else if value != expectedInt {
|
||||||
t.Errorf("expected value %d, got %d", expectedInt, value)
|
t.Errorf("expected value %d, got %d", expectedInt, value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found {
|
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found {
|
||||||
t.Errorf("should not have found value for \"foo\", got %d", value)
|
t.Errorf("should not have found value for \"foo\", got %d", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found {
|
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found {
|
||||||
t.Errorf("should not have found value for \"invalid\", got %d", value)
|
t.Errorf("should not have found value for \"invalid\", got %d", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
}
|
}
|
||||||
if value, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found {
|
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found {
|
||||||
t.Errorf("should not have found value for \"baz\", got %d", value)
|
t.Errorf("should not have found value for \"baz\", got %d", value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidateCapabilities(t *testing.T) {
|
||||||
|
var called uint32
|
||||||
|
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) {
|
||||||
|
atomic.AddUint32(&called, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
expectedString := "bar"
|
||||||
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||||
|
t.Error("could not find value for \"foo\"")
|
||||||
|
} else if value != expectedString {
|
||||||
|
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||||
|
} else if cached {
|
||||||
|
t.Errorf("expected direct response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := atomic.LoadUint32(&called); value != 1 {
|
||||||
|
t.Errorf("expected called %d, got %d", 1, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidating will cause the capabilities to be reloaded.
|
||||||
|
capabilities.InvalidateCapabilities(url)
|
||||||
|
|
||||||
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||||
|
t.Error("could not find value for \"foo\"")
|
||||||
|
} else if value != expectedString {
|
||||||
|
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||||
|
} else if cached {
|
||||||
|
t.Errorf("expected direct response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := atomic.LoadUint32(&called); value != 2 {
|
||||||
|
t.Errorf("expected called %d, got %d", 2, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidating is throttled to about once per minute.
|
||||||
|
capabilities.InvalidateCapabilities(url)
|
||||||
|
|
||||||
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||||
|
t.Error("could not find value for \"foo\"")
|
||||||
|
} else if value != expectedString {
|
||||||
|
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||||
|
} else if !cached {
|
||||||
|
t.Errorf("expected cached response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := atomic.LoadUint32(&called); value != 2 {
|
||||||
|
t.Errorf("expected called %d, got %d", 2, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// At a later time, invalidating can be done again.
|
||||||
|
SetCapabilitiesGetNow(t, func() time.Time {
|
||||||
|
return time.Now().Add(2 * time.Minute)
|
||||||
|
})
|
||||||
|
|
||||||
|
capabilities.InvalidateCapabilities(url)
|
||||||
|
|
||||||
|
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
|
||||||
|
t.Error("could not find value for \"foo\"")
|
||||||
|
} else if value != expectedString {
|
||||||
|
t.Errorf("expected value %s, got %s", expectedString, value)
|
||||||
|
} else if cached {
|
||||||
|
t.Errorf("expected direct response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := atomic.LoadUint32(&called); value != 3 {
|
||||||
|
t.Errorf("expected called %d, got %d", 3, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
12
hub.go
12
hub.go
|
@ -1063,9 +1063,17 @@ func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend,
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
keyData, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
|
keyData, cached, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
|
||||||
if !found {
|
if !found {
|
||||||
return nil, fmt.Errorf("No key found for issuer")
|
if cached {
|
||||||
|
// The Nextcloud instance might just have enabled JWT but we probably use
|
||||||
|
// the cached capabilities without the public key. Make sure to re-fetch.
|
||||||
|
h.backend.capabilities.InvalidateCapabilities(url)
|
||||||
|
keyData, _, found = h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("No key found for issuer")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := loadKeyFunc([]byte(keyData))
|
key, err := loadKeyFunc([]byte(keyData))
|
||||||
|
|
59
hub_test.go
59
hub_test.go
|
@ -697,7 +697,11 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) {
|
||||||
if strings.Contains(t.Name(), "MultiRoom") {
|
if strings.Contains(t.Name(), "MultiRoom") {
|
||||||
signaling[ConfigKeySessionPingLimit] = 2
|
signaling[ConfigKeySessionPingLimit] = 2
|
||||||
}
|
}
|
||||||
if strings.Contains(t.Name(), "V2") {
|
useV2 := true
|
||||||
|
if os.Getenv("SKIP_V2_CAPABILITIES") != "" {
|
||||||
|
useV2 = false
|
||||||
|
}
|
||||||
|
if strings.Contains(t.Name(), "V2") && useV2 {
|
||||||
key := getPublicAuthToken(t)
|
key := getPublicAuthToken(t)
|
||||||
public, err := x509.MarshalPKIXPublicKey(key)
|
public, err := x509.MarshalPKIXPublicKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1060,6 +1064,59 @@ func TestClientHelloV2_ExpiresAtMissing(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientHelloV2_CachedCapabilities(t *testing.T) {
|
||||||
|
for _, algo := range testHelloV2Algorithms {
|
||||||
|
t.Run(algo, func(t *testing.T) {
|
||||||
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Simulate old-style Nextcloud without capabilities for Hello V2.
|
||||||
|
t.Setenv("SKIP_V2_CAPABILITIES", "1")
|
||||||
|
|
||||||
|
client1 := NewTestClient(t, server, hub)
|
||||||
|
defer client1.CloseWithBye()
|
||||||
|
|
||||||
|
if err := client1.SendHelloV1(testDefaultUserId + "1"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hello1, err := client1.RunUntilHello(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if hello1.Hello.UserId != testDefaultUserId+"1" {
|
||||||
|
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"1", hello1.Hello)
|
||||||
|
}
|
||||||
|
if hello1.Hello.SessionId == "" {
|
||||||
|
t.Errorf("Expected session id, got %+v", hello1.Hello)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate updated Nextcloud with capabilities for Hello V2.
|
||||||
|
t.Setenv("SKIP_V2_CAPABILITIES", "")
|
||||||
|
|
||||||
|
client2 := NewTestClient(t, server, hub)
|
||||||
|
defer client2.CloseWithBye()
|
||||||
|
|
||||||
|
if err := client2.SendHelloV2(testDefaultUserId + "2"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hello2, err := client2.RunUntilHello(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if hello2.Hello.UserId != testDefaultUserId+"2" {
|
||||||
|
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello2.Hello)
|
||||||
|
}
|
||||||
|
if hello2.Hello.SessionId == "" {
|
||||||
|
t.Errorf("Expected session id, got %+v", hello2.Hello)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClientHelloWithSpaces(t *testing.T) {
|
func TestClientHelloWithSpaces(t *testing.T) {
|
||||||
hub, _, _, server := CreateHubForTest(t)
|
hub, _, _, server := CreateHubForTest(t)
|
||||||
|
|
||||||
|
|
|
@ -119,7 +119,7 @@ func (p *RoomPing) publishEntries(entries *pingEntries, timeout time.Duration) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
limit, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
|
limit, _, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
|
||||||
if !found || limit <= 0 {
|
if !found || limit <= 0 {
|
||||||
// Limit disabled while waiting for the next iteration, fallback to sending
|
// Limit disabled while waiting for the next iteration, fallback to sending
|
||||||
// one request per room.
|
// one request per room.
|
||||||
|
@ -188,7 +188,7 @@ func (p *RoomPing) sendPingsCombined(url *url.URL, entries []BackendPingEntry, l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RoomPing) SendPings(ctx context.Context, room *Room, url *url.URL, entries []BackendPingEntry) error {
|
func (p *RoomPing) SendPings(ctx context.Context, room *Room, url *url.URL, entries []BackendPingEntry) error {
|
||||||
limit, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
|
limit, _, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
|
||||||
if !found || limit <= 0 {
|
if !found || limit <= 0 {
|
||||||
// Old-style Nextcloud or session limit not configured. Perform one request
|
// Old-style Nextcloud or session limit not configured. Perform one request
|
||||||
// per room. Don't queue to avoid sending all ping requests to old-style
|
// per room. Don't queue to avoid sending all ping requests to old-style
|
||||||
|
|
Loading…
Reference in a new issue