diff --git a/capabilities.go b/capabilities.go index 491bd7a..12914bc 100644 --- a/capabilities.go +++ b/capabilities.go @@ -43,8 +43,14 @@ const ( // Cache received capabilities for one hour. CapabilitiesCacheDuration = time.Hour + + // Don't invalidate more than once per minute. + maxInvalidateInterval = time.Minute ) +// Can be overwritten by tests. +var getCapabilitiesNow = time.Now + type capabilitiesEntry struct { nextUpdate time.Time capabilities map[string]interface{} @@ -53,16 +59,18 @@ type capabilitiesEntry struct { type Capabilities struct { mu sync.RWMutex - version string - pool *HttpClientPool - entries map[string]*capabilitiesEntry + version string + pool *HttpClientPool + entries map[string]*capabilitiesEntry + nextInvalidate map[string]time.Time } func NewCapabilities(version string, pool *HttpClientPool) (*Capabilities, error) { result := &Capabilities{ - version: version, - pool: pool, - entries: make(map[string]*capabilitiesEntry), + version: version, + pool: pool, + entries: make(map[string]*capabilitiesEntry), + nextInvalidate: make(map[string]time.Time), } return result, nil @@ -86,7 +94,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool c.mu.RLock() defer c.mu.RUnlock() - now := time.Now() + now := getCapabilitiesNow() if entry, found := c.entries[key]; found && entry.nextUpdate.After(now) { return entry.capabilities, true } @@ -95,7 +103,7 @@ func (c *Capabilities) getCapabilities(key string) (map[string]interface{}, bool } func (c *Capabilities) setCapabilities(key string, capabilities map[string]interface{}) { - now := time.Now() + now := getCapabilitiesNow() entry := &capabilitiesEntry{ nextUpdate: now.Add(CapabilitiesCacheDuration), capabilities: capabilities, @@ -106,11 +114,28 @@ func (c *Capabilities) setCapabilities(key string, capabilities map[string]inter c.entries[key] = entry } -func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, error) { - key := u.String() +func (c *Capabilities) invalidateCapabilities(key string) { + c.mu.Lock() + defer c.mu.Unlock() + now := getCapabilitiesNow() + if entry, found := c.nextInvalidate[key]; found && entry.After(now) { + return + } + + delete(c.entries, key) + c.nextInvalidate[key] = now.Add(maxInvalidateInterval) +} + +func (c *Capabilities) getKeyForUrl(u *url.URL) string { + key := u.String() + return key +} + +func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[string]interface{}, bool, error) { + key := c.getKeyForUrl(u) if caps, found := c.getCapabilities(key); found { - return caps, nil + return caps, true, nil } capUrl := *u @@ -128,14 +153,14 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st client, pool, err := c.pool.Get(ctx, &capUrl) if err != nil { log.Printf("Could not get client for host %s: %s", capUrl.Host, err) - return nil, err + return nil, false, err } defer pool.Put(client) req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil) if err != nil { log.Printf("Could not create request to %s: %s", &capUrl, err) - return nil, err + return nil, false, err } req.Header.Set("Accept", "application/json") req.Header.Set("OCS-APIRequest", "true") @@ -143,56 +168,56 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st resp, err := client.Do(req) if err != nil { - return nil, err + return nil, false, err } defer resp.Body.Close() ct := resp.Header.Get("Content-Type") if !strings.HasPrefix(ct, "application/json") { log.Printf("Received unsupported content-type from %s: %s (%s)", capUrl.String(), ct, resp.Status) - return nil, ErrUnsupportedContentType + return nil, false, ErrUnsupportedContentType } body, err := io.ReadAll(resp.Body) if err != nil { log.Printf("Could not read response body from %s: %s", capUrl.String(), err) - return nil, err + return nil, false, err } var ocs OcsResponse if err := json.Unmarshal(body, &ocs); err != nil { log.Printf("Could not decode OCS response %s from %s: %s", string(body), capUrl.String(), err) - return nil, err + return nil, false, err } else if ocs.Ocs == nil || ocs.Ocs.Data == nil { log.Printf("Incomplete OCS response %s from %s", string(body), u) - return nil, fmt.Errorf("incomplete OCS response") + return nil, false, fmt.Errorf("incomplete OCS response") } var response CapabilitiesResponse if err := json.Unmarshal(*ocs.Ocs.Data, &response); err != nil { log.Printf("Could not decode OCS response body %s from %s: %s", string(*ocs.Ocs.Data), capUrl.String(), err) - return nil, err + return nil, false, err } capaObj, found := response.Capabilities[AppNameSpreed] if !found || capaObj == nil { log.Printf("No capabilities received for app spreed from %s: %+v", capUrl.String(), response) - return nil, nil + return nil, false, nil } var capa map[string]interface{} if err := json.Unmarshal(*capaObj, &capa); err != nil { log.Printf("Unsupported capabilities received for app spreed from %s: %+v", capUrl.String(), response) - return nil, nil + return nil, false, nil } log.Printf("Received capabilities %+v from %s", capa, capUrl.String()) c.setCapabilities(key, capa) - return capa, nil + return capa, false, nil } func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool { - caps, err := c.loadCapabilities(ctx, u) + caps, _, err := c.loadCapabilities(ctx, u) if err != nil { log.Printf("Could not get capabilities for %s: %s", u, err) return false @@ -217,80 +242,86 @@ func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, fea return false } -func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool) { - caps, err := c.loadCapabilities(ctx, u) +func (c *Capabilities) getConfigGroup(ctx context.Context, u *url.URL, group string) (map[string]interface{}, bool, bool) { + caps, cached, err := c.loadCapabilities(ctx, u) if err != nil { log.Printf("Could not get capabilities for %s: %s", u, err) - return nil, false + return nil, cached, false } configInterface := caps["config"] if configInterface == nil { - return nil, false + return nil, cached, false } config, ok := configInterface.(map[string]interface{}) if !ok { log.Printf("Invalid config mapping received from %s: %+v", u, configInterface) - return nil, false + return nil, cached, false } groupInterface := config[group] if groupInterface == nil { - return nil, false + return nil, cached, false } groupConfig, ok := groupInterface.(map[string]interface{}) if !ok { log.Printf("Invalid group mapping \"%s\" received from %s: %+v", group, u, groupInterface) - return nil, false + return nil, cached, false } - return groupConfig, true + return groupConfig, cached, true } -func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool) { - groupConfig, found := c.getConfigGroup(ctx, u, group) +func (c *Capabilities) GetIntegerConfig(ctx context.Context, u *url.URL, group, key string) (int, bool, bool) { + groupConfig, cached, found := c.getConfigGroup(ctx, u, group) if !found { - return 0, false + return 0, cached, false } value, found := groupConfig[key] if !found { - return 0, false + return 0, cached, false } switch value := value.(type) { case int: - return value, true + return value, cached, true case float32: - return int(value), true + return int(value), cached, true case float64: - return int(value), true + return int(value), cached, true default: log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value) } - return 0, false + return 0, cached, false } -func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool) { - groupConfig, found := c.getConfigGroup(ctx, u, group) +func (c *Capabilities) GetStringConfig(ctx context.Context, u *url.URL, group, key string) (string, bool, bool) { + groupConfig, cached, found := c.getConfigGroup(ctx, u, group) if !found { - return "", false + return "", cached, false } value, found := groupConfig[key] if !found { - return "", false + return "", cached, false } switch value := value.(type) { case string: - return value, true + return value, cached, true default: log.Printf("Invalid config value for \"%s\" received from %s: %+v", key, u, value) } - return "", false + return "", cached, false +} + +func (c *Capabilities) InvalidateCapabilities(u *url.URL) { + key := c.getKeyForUrl(u) + + c.invalidateCapabilities(key) } diff --git a/capabilities_test.go b/capabilities_test.go index 19eb087..22f653b 100644 --- a/capabilities_test.go +++ b/capabilities_test.go @@ -28,12 +28,14 @@ import ( "net/http/httptest" "net/url" "strings" + "sync/atomic" "testing" + "time" "github.com/gorilla/mux" ) -func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { +func NewCapabilitiesForTestWithCallback(t *testing.T, callback func(*CapabilitiesResponse)) (*url.URL, *Capabilities) { pool, err := NewHttpClientPool(1, false) if err != nil { t.Fatal(err) @@ -84,6 +86,10 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { }, } + if callback != nil { + callback(response) + } + data, err := json.Marshal(response) if err != nil { t.Errorf("Could not marshal %+v: %s", response, err) @@ -110,6 +116,19 @@ func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { return u, capabilities } +func NewCapabilitiesForTest(t *testing.T) (*url.URL, *Capabilities) { + return NewCapabilitiesForTestWithCallback(t, nil) +} + +func SetCapabilitiesGetNow(t *testing.T, f func() time.Time) { + old := getCapabilitiesNow + t.Cleanup(func() { + getCapabilitiesNow = old + }) + + getCapabilitiesNow = f +} + func TestCapabilities(t *testing.T) { url, capabilities := NewCapabilitiesForTest(t) @@ -124,34 +143,122 @@ func TestCapabilities(t *testing.T) { } expectedString := "bar" - if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { t.Error("could not find value for \"foo\"") } else if value != expectedString { t.Errorf("expected value %s, got %s", expectedString, value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "baz"); found { t.Errorf("should not have found value for \"baz\", got %s", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "invalid"); found { t.Errorf("should not have found value for \"invalid\", got %s", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found { + if value, cached, found := capabilities.GetStringConfig(ctx, url, "invalid", "foo"); found { t.Errorf("should not have found value for \"baz\", got %s", value) + } else if !cached { + t.Errorf("expected cached response") } expectedInt := 42 - if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "baz"); !found { t.Error("could not find value for \"baz\"") } else if value != expectedInt { t.Errorf("expected value %d, got %d", expectedInt, value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "foo"); found { t.Errorf("should not have found value for \"foo\", got %d", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "signaling", "invalid"); found { t.Errorf("should not have found value for \"invalid\", got %d", value) + } else if !cached { + t.Errorf("expected cached response") } - if value, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found { + if value, cached, found := capabilities.GetIntegerConfig(ctx, url, "invalid", "baz"); found { t.Errorf("should not have found value for \"baz\", got %d", value) + } else if !cached { + t.Errorf("expected cached response") + } +} + +func TestInvalidateCapabilities(t *testing.T) { + var called uint32 + url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse) { + atomic.AddUint32(&called, 1) + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + expectedString := "bar" + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := atomic.LoadUint32(&called); value != 1 { + t.Errorf("expected called %d, got %d", 1, value) + } + + // Invalidating will cause the capabilities to be reloaded. + capabilities.InvalidateCapabilities(url) + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := atomic.LoadUint32(&called); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } + + // Invalidating is throttled to about once per minute. + capabilities.InvalidateCapabilities(url) + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if !cached { + t.Errorf("expected cached response") + } + + if value := atomic.LoadUint32(&called); value != 2 { + t.Errorf("expected called %d, got %d", 2, value) + } + + // At a later time, invalidating can be done again. + SetCapabilitiesGetNow(t, func() time.Time { + return time.Now().Add(2 * time.Minute) + }) + + capabilities.InvalidateCapabilities(url) + + if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); !found { + t.Error("could not find value for \"foo\"") + } else if value != expectedString { + t.Errorf("expected value %s, got %s", expectedString, value) + } else if cached { + t.Errorf("expected direct response") + } + + if value := atomic.LoadUint32(&called); value != 3 { + t.Errorf("expected called %d, got %d", 3, value) } } diff --git a/hub.go b/hub.go index 2b982ca..c7dc1af 100644 --- a/hub.go +++ b/hub.go @@ -1063,9 +1063,17 @@ func (h *Hub) processHelloV2(client *Client, message *ClientMessage) (*Backend, ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout) defer cancel() - keyData, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey) + keyData, cached, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey) if !found { - return nil, fmt.Errorf("No key found for issuer") + if cached { + // The Nextcloud instance might just have enabled JWT but we probably use + // the cached capabilities without the public key. Make sure to re-fetch. + h.backend.capabilities.InvalidateCapabilities(url) + keyData, _, found = h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey) + } + if !found { + return nil, fmt.Errorf("No key found for issuer") + } } key, err := loadKeyFunc([]byte(keyData)) diff --git a/hub_test.go b/hub_test.go index 9e56a7e..abe6903 100644 --- a/hub_test.go +++ b/hub_test.go @@ -697,7 +697,11 @@ func registerBackendHandlerUrl(t *testing.T, router *mux.Router, url string) { if strings.Contains(t.Name(), "MultiRoom") { signaling[ConfigKeySessionPingLimit] = 2 } - if strings.Contains(t.Name(), "V2") { + useV2 := true + if os.Getenv("SKIP_V2_CAPABILITIES") != "" { + useV2 = false + } + if strings.Contains(t.Name(), "V2") && useV2 { key := getPublicAuthToken(t) public, err := x509.MarshalPKIXPublicKey(key) if err != nil { @@ -1060,6 +1064,59 @@ func TestClientHelloV2_ExpiresAtMissing(t *testing.T) { } } +func TestClientHelloV2_CachedCapabilities(t *testing.T) { + for _, algo := range testHelloV2Algorithms { + t.Run(algo, func(t *testing.T) { + hub, _, _, server := CreateHubForTest(t) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + // Simulate old-style Nextcloud without capabilities for Hello V2. + t.Setenv("SKIP_V2_CAPABILITIES", "1") + + client1 := NewTestClient(t, server, hub) + defer client1.CloseWithBye() + + if err := client1.SendHelloV1(testDefaultUserId + "1"); err != nil { + t.Fatal(err) + } + + hello1, err := client1.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + if hello1.Hello.UserId != testDefaultUserId+"1" { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"1", hello1.Hello) + } + if hello1.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello1.Hello) + } + + // Simulate updated Nextcloud with capabilities for Hello V2. + t.Setenv("SKIP_V2_CAPABILITIES", "") + + client2 := NewTestClient(t, server, hub) + defer client2.CloseWithBye() + + if err := client2.SendHelloV2(testDefaultUserId + "2"); err != nil { + t.Fatal(err) + } + + hello2, err := client2.RunUntilHello(ctx) + if err != nil { + t.Fatal(err) + } + if hello2.Hello.UserId != testDefaultUserId+"2" { + t.Errorf("Expected \"%s\", got %+v", testDefaultUserId+"2", hello2.Hello) + } + if hello2.Hello.SessionId == "" { + t.Errorf("Expected session id, got %+v", hello2.Hello) + } + }) + } +} + func TestClientHelloWithSpaces(t *testing.T) { hub, _, _, server := CreateHubForTest(t) diff --git a/room_ping.go b/room_ping.go index 2e83fe9..48c301a 100644 --- a/room_ping.go +++ b/room_ping.go @@ -119,7 +119,7 @@ func (p *RoomPing) publishEntries(entries *pingEntries, timeout time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - limit, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit) + limit, _, found := p.capabilities.GetIntegerConfig(ctx, entries.url, ConfigGroupSignaling, ConfigKeySessionPingLimit) if !found || limit <= 0 { // Limit disabled while waiting for the next iteration, fallback to sending // one request per room. @@ -188,7 +188,7 @@ func (p *RoomPing) sendPingsCombined(url *url.URL, entries []BackendPingEntry, l } func (p *RoomPing) SendPings(ctx context.Context, room *Room, url *url.URL, entries []BackendPingEntry) error { - limit, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit) + limit, _, found := p.capabilities.GetIntegerConfig(ctx, url, ConfigGroupSignaling, ConfigKeySessionPingLimit) if !found || limit <= 0 { // Old-style Nextcloud or session limit not configured. Perform one request // per room. Don't queue to avoid sending all ping requests to old-style