Update capabilities if no hello v2 token key is found in cache.

This is necessary to detect updated Talk setups where the signaling server
might have cached capabilities without the v2 token key but the clients
are trying to connect with a hello v2 token. Fetch updated capabilities in
such cases (but throttle to about one invalidation per minute).
This commit is contained in:
Joachim Bauch 2022-08-03 17:15:02 +02:00
parent 184c941f8a
commit cbb6d9ca53
No known key found for this signature in database
GPG Key ID: 77C1D22D53E15F02
5 changed files with 262 additions and 59 deletions

View File

@ -43,8 +43,14 @@ const (
// Cache received capabilities for one hour.
CapabilitiesCacheDuration = time.Hour
// Don't invalidate more than once per minute.
maxInvalidateInterval = time.Minute
)
// Can be overwritten by tests.
var getCapabilitiesNow = time.Now
type capabilitiesEntry struct {
nextUpdate time.Time
capabilities map[string]interface{}
@ -53,16 +59,18 @@ type capabilitiesEntry struct {
type Capabilities struct {
mu sync.RWMutex
version string
pool *HttpClientPool
entries map[string]*capabilitiesEntry
version string
pool *HttpClientPool
entries map[string]*capabilitiesEntry
nextInvalidate map[string]time.Time
}
func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) {
result := &Capabilities{
version: version,
pool: pool,
entries: make(map[string]*capabilitiesEntry),
version: version,
pool: pool,
entries: make(map[string]*capabilitiesEntry),
nextInvalidate: make(map[string]time.Time),
}
return result, nil
@ -86,7 +94,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool
c.mu.RLock()
defer c.mu.RUnlock()
now := time.Now()
now := getCapabilitiesNow()
if entry, found := c.entries[key]; found && entry.nextUpdate.After(now) {
return entry.capabilities, true
}
@ -95,7 +103,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool
}
func (c *Capabilities) setCapabilities(key string, capabilities map[string]interface{}) {
now := time.Now()
now := getCapabilitiesNow()
entry := &capabilitiesEntry{
nextUpdate: now.Add(CapabilitiesCacheDuration),
capabilities: capabilities,
@ -106,11 +114,28 @@ func (c *Capabilities) setCapabilities(key string, capabilities map[string]inter
c.entries[key] = entry
}
func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, error) {
key := u.String()
func (c *Capabilities) invalidateCapabilities(key string) {
c.mu.Lock()
defer c.mu.Unlock()
now := getCapabilitiesNow()
if entry, found := c.nextInvalidate[key]; found && entry.After(now) {
return
}
delete(c.entries, key)
c.nextInvalidate[key] = now.Add(maxInvalidateInterval)
}
func (c *Capabilities) getKeyForUrl(u *url.URL) string {
key := u.String()
return key
}
func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, bool, error) {
key := c.getKeyForUrl(u)
if caps, found := c.getCapabilities(key); found {
return caps, nil
return caps, true, nil
}
capUrl := *u
@ -128,14 +153,14 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st
client, pool, err := c.pool.Get(ctx, &capUrl)
if err != nil {
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
return nil, err
return nil, false, err
}
defer pool.Put(client)
req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
if err != nil {
log.Printf("Could not create request to %s: %s", &capUrl, err)
return nil, err
return nil, false, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("OCS-APIRequest", "true")
@ -143,56 +168,56 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st
resp, err := client.Do(req)
if err != nil {
return nil, err
return nil, false, err
}
defer resp.Body.Close()
ct := resp.Header.Get("Content-Type")
if !strings.HasPrefix(ct, "application/json") {
log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status)
return nil, ErrUnsupportedContentType
return nil, false, ErrUnsupportedContentType
}
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("Could not read response body from %s: %s", capUrl.String(), err)
return nil, err
return nil, false, err
}
var ocs OcsResponse
if err := json.Unmarshal(body, &ocs); err != nil {
log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err)
return nil, err
return nil, false, err
} else if ocs.Ocs == nil || ocs.Ocs.Data == nil {
log.Printf("Incomplete OCS response %s from %s", string(body), u)
return nil, fmt.Errorf("incomplete OCS response")
return nil, false, fmt.Errorf("incomplete OCS response")
}
var response CapabilitiesResponse
if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil {
log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err)
return nil, err
return nil, false, err
}
capaObj, found := response.Capabilities[AppNameSpreed]
if !found || capaObj == nil {
log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response)
return nil, nil
return nil, false, nil
}
var capa map[string]interface{}
if err := json.Unmarshal(*capaObj, &capa); err != nil {
log.Printf("Unsupported capabilities received for app spreed from %s: %+v", capUrl.String(), response)
return nil, nil
return nil, false, nil
}
log.Printf("Received capabilities %+v from %s", capa, capUrl.String())
c.setCapabilities(key, capa)
return capa, nil
return capa, false, nil
}
func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool {
caps, err := c.loadCapabilities(ctx, u)
caps, _, err := c.loadCapabilities(ctx, u)
if err != nil {
log.Printf("Could not get capabilities for %s: %s", u, err)
return false
@ -217,80 +242,86 @@ func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, fea
return false
}
func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool) {
caps, err := c.loadCapabilities(ctx, u)
func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool, bool) {
caps, cached, err := c.loadCapabilities(ctx, u)
if err != nil {
log.Printf("Could not get capabilities for %s: %s", u, err)
return nil, false
return nil, cached, false
}
configInterface := caps["config"]
if configInterface == nil {
return nil, false
return nil, cached, false
}
config, ok := configInterface.(map[string]interface{})
if !ok {
log.Printf("Invalid config mapping received from %s: %+v", u, configInterface)
return nil, false
return nil, cached, false
}
groupInterface := config[group]
if groupInterface == nil {
return nil, false
return nil, cached, false
}
groupConfig, ok := groupInterface.(map[string]interface{})
if !ok {
log.Printf("Invalid group mapping \"%s\" received from %s: %+v", group, u, groupInterface)
return nil, false
return nil, cached, false
}
return groupConfig, true
return groupConfig, cached, true
}
func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool) {
groupConfig, found := c.getConfigGroup(ctx, u, group)
func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool, bool) {
groupConfig, cached, found := c.getConfigGroup(ctx, u, group)
if !found {
return 0, false
return 0, cached, false
}
value, found := groupConfig[key]
if !found {
return 0, false
return 0, cached, false
}
switch value := value.(type) {
case int:
return value, true
return value, cached, true
case float32:
return int(value), true
return int(value), cached, true
case float64:
return int(value), true
return int(value), cached, true
default:
log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value)
}
return 0, false
return 0, cached, false
}
func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool) {
groupConfig, found := c.getConfigGroup(ctx, u, group)
func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool, bool) {
groupConfig, cached, found := c.getConfigGroup(ctx, u, group)
if !found {
return "", false
return "", cached, false
}
value, found := groupConfig[key]
if !found {
return "", false
return "", cached, false
}
switch value := value.(type) {
case string:
return value, true
return value, cached, true
default:
log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value)
}
return "", false
return "", cached, false
}
func (c *Capabilities) InvalidateCapabilities(u *url.URL) {
key := c.getKeyForUrl(u)
c.invalidateCapabilities(key)
}

View File

@ -28,12 +28,14 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/gorilla/mux"
)
func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse)) (*url.URL, *Capabilities) {
pool, err := NewHttpClientPool(1, false)
if err != nil {
t.Fatal(err)
@ -84,6 +86,10 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
},
}
if callback != nil {
callback(response)
}
data, err := json.Marshal(response)
if err != nil {
t.Errorf("Could not marshal %+v: %s", response, err)
@ -110,6 +116,19 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
return u, capabilities
}
func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) {
return NewCapabilitiesForTestWithCallback(t, nil)
}
func SetCapabilitiesGetNow(t *testing.T, f func() time.Time) {
old := getCapabilitiesNow
t.Cleanup(func() {
getCapabilitiesNow = old
})
getCapabilitiesNow = f
}
func TestCapabilities(t *testing.T) {
url, capabilities := NewCapabilitiesForTest(t)
@ -124,34 +143,122 @@ func TestCapabilities(t *testing.T) {
}
expectedString := "bar"
if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if !cached {
t.Errorf("expected cached response")
}
if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found {
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found {
t.Errorf("should not have found value for \"baz\", got %s", value)
} else if !cached {
t.Errorf("expected cached response")
}
if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found {
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found {
t.Errorf("should not have found value for \"invalid\", got %s", value)
} else if !cached {
t.Errorf("expected cached response")
}
if value, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found {
if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found {
t.Errorf("should not have found value for \"baz\", got %s", value)
} else if !cached {
t.Errorf("expected cached response")
}
expectedInt := 42
if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found {
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found {
t.Error("could not find value for \"baz\"")
} else if value != expectedInt {
t.Errorf("expected value %d, got %d", expectedInt, value)
} else if !cached {
t.Errorf("expected cached response")
}
if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found {
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found {
t.Errorf("should not have found value for \"foo\", got %d", value)
} else if !cached {
t.Errorf("expected cached response")
}
if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found {
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found {
t.Errorf("should not have found value for \"invalid\", got %d", value)
} else if !cached {
t.Errorf("expected cached response")
}
if value, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found {
if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found {
t.Errorf("should not have found value for \"baz\", got %d", value)
} else if !cached {
t.Errorf("expected cached response")
}
}
func TestInvalidateCapabilities(t *testing.T) {
var called uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) {
atomic.AddUint32(&called, 1)
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
}
if value := atomic.LoadUint32(&called); value != 1 {
t.Errorf("expected called %d, got %d", 1, value)
}
// Invalidating will cause the capabilities to be reloaded.
capabilities.InvalidateCapabilities(url)
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
}
if value := atomic.LoadUint32(&called); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}
// Invalidating is throttled to about once per minute.
capabilities.InvalidateCapabilities(url)
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if !cached {
t.Errorf("expected cached response")
}
if value := atomic.LoadUint32(&called); value != 2 {
t.Errorf("expected called %d, got %d", 2, value)
}
// At a later time, invalidating can be done again.
SetCapabilitiesGetNow(t, func() time.Time {
return time.Now().Add(2 * time.Minute)
})
capabilities.InvalidateCapabilities(url)
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found {
t.Error("could not find value for \"foo\"")
} else if value != expectedString {
t.Errorf("expected value %s, got %s", expectedString, value)
} else if cached {
t.Errorf("expected direct response")
}
if value := atomic.LoadUint32(&called); value != 3 {
t.Errorf("expected called %d, got %d", 3, value)
}
}

12
hub.go
View File

@ -1063,9 +1063,17 @@ func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend,
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
defer cancel()
keyData, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
keyData, cached, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
if !found {
return nil, fmt.Errorf("No key found for issuer")
if cached {
// The Nextcloud instance might just have enabled JWT but we probably use
// the cached capabilities without the public key. Make sure to re-fetch.
h.backend.capabilities.InvalidateCapabilities(url)
keyData, _, found = h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
}
if !found {
return nil, fmt.Errorf("No key found for issuer")
}
}
key, err := loadKeyFunc([]byte(keyData))

View File

@ -697,7 +697,11 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) {
if strings.Contains(t.Name(), "MultiRoom") {
signaling[ConfigKeySessionPingLimit] = 2
}
if strings.Contains(t.Name(), "V2") {
useV2 := true
if os.Getenv("SKIP_V2_CAPABILITIES") != "" {
useV2 = false
}
if strings.Contains(t.Name(), "V2") && useV2 {
key := getPublicAuthToken(t)
public, err := x509.MarshalPKIXPublicKey(key)
if err != nil {
@ -1060,6 +1064,59 @@ func TestClientHelloV2_ExpiresAtMissing(t *testing.T) {
}
}
func TestClientHelloV2_CachedCapabilities(t *testing.T) {
for _, algo := range testHelloV2Algorithms {
t.Run(algo, func(t *testing.T) {
hub, _, _, server := CreateHubForTest(t)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
// Simulate old-style Nextcloud without capabilities for Hello V2.
t.Setenv("SKIP_V2_CAPABILITIES", "1")
client1 := NewTestClient(t, server, hub)
defer client1.CloseWithBye()
if err := client1.SendHelloV1(testDefaultUserId + "1"); err != nil {
t.Fatal(err)
}
hello1, err := client1.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
if hello1.Hello.UserId != testDefaultUserId+"1" {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"1", hello1.Hello)
}
if hello1.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello1.Hello)
}
// Simulate updated Nextcloud with capabilities for Hello V2.
t.Setenv("SKIP_V2_CAPABILITIES", "")
client2 := NewTestClient(t, server, hub)
defer client2.CloseWithBye()
if err := client2.SendHelloV2(testDefaultUserId + "2"); err != nil {
t.Fatal(err)
}
hello2, err := client2.RunUntilHello(ctx)
if err != nil {
t.Fatal(err)
}
if hello2.Hello.UserId != testDefaultUserId+"2" {
t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello2.Hello)
}
if hello2.Hello.SessionId == "" {
t.Errorf("Expected session id, got %+v", hello2.Hello)
}
})
}
}
func TestClientHelloWithSpaces(t *testing.T) {
hub, _, _, server := CreateHubForTest(t)

View File

@ -119,7 +119,7 @@ func (p *RoomPing) publishEntries(entries *pingEntries, timeout time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
limit, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
limit, _, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
if !found || limit <= 0 {
// Limit disabled while waiting for the next iteration, fallback to sending
// one request per room.
@ -188,7 +188,7 @@ func (p *RoomPing) sendPingsCombined(url *url.URL, entries []BackendPingEntry, l
}
func (p *RoomPing) SendPings(ctx context.Context, room *Room, url *url.URL, entries []BackendPingEntry) error {
limit, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
limit, _, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit)
if !found || limit <= 0 {
// Old-style Nextcloud or session limit not configured. Perform one request
// per room. Don't queue to avoid sending all ping requests to old-style