Don't update capabilities concurrently from same host.

If capabilities are expired and requested from multiple clients concurrently,
this could cause concurrent (duplicate) requests to the same Nextcloud host.
With this change, only a single request is sent to Nextcloud in such cases.
This commit is contained in:
Joachim Bauch 2024-10-09 15:13:48 +02:00
commit 147fb2305c
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
2 changed files with 124 additions and 62 deletions

View file

@ -60,6 +60,7 @@ var (
)
type capabilitiesEntry struct {
c *Capabilities
mu sync.RWMutex
nextUpdate time.Time
etag string
@ -67,21 +68,24 @@ type capabilitiesEntry struct {
capabilities map[string]interface{}
}
func newCapabilitiesEntry() *capabilitiesEntry {
return &capabilitiesEntry{}
func newCapabilitiesEntry(c *Capabilities) *capabilitiesEntry {
return &capabilitiesEntry{
c: c,
}
}
func (e *capabilitiesEntry) valid(now time.Time) bool {
e.mu.RLock()
defer e.mu.RUnlock()
return e.validLocked(now)
}
func (e *capabilitiesEntry) validLocked(now time.Time) bool {
return e.nextUpdate.After(now)
}
func (e *capabilitiesEntry) updateRequest(r *http.Request) {
e.mu.RLock()
defer e.mu.RUnlock()
if e.etag != "" {
r.Header.Set("If-None-Match", e.etag)
}
@ -94,19 +98,59 @@ func (e *capabilitiesEntry) invalidate() {
e.nextUpdate = time.Now()
}
func (e *capabilitiesEntry) errorIfMustRevalidate(err error) error {
func (e *capabilitiesEntry) errorIfMustRevalidate(err error) (bool, error) {
if !e.mustRevalidate {
return nil
return false, nil
}
e.capabilities = nil
return err
return false, err
}
func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error {
func (e *capabilitiesEntry) update(ctx context.Context, u *url.URL, now time.Time) (bool, error) {
e.mu.Lock()
defer e.mu.Unlock()
if e.validLocked(now) {
// Capabilities were updated while waiting for the lock.
return false, nil
}
capUrl := *u
if !strings.Contains(capUrl.Path, "ocs/v2.php") {
if !strings.HasSuffix(capUrl.Path, "/") {
capUrl.Path += "/"
}
capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities"
} else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 {
capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities"
}
log.Printf("Capabilities expired for %s, updating", capUrl.String())
client, pool, err := e.c.pool.Get(ctx, &capUrl)
if err != nil {
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
return false, err
}
defer pool.Put(client)
req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
if err != nil {
log.Printf("Could not create request to %s: %s", &capUrl, err)
return false, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("OCS-APIRequest", "true")
req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+e.c.version)
e.updateRequest(req)
response, err := client.Do(req)
if err != nil {
return false, err
}
defer response.Body.Close()
url := response.Request.URL
e.etag = response.Header.Get("ETag")
@ -127,7 +171,7 @@ func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error
if response.StatusCode == http.StatusNotModified {
log.Printf("Capabilities %+v from %s have not changed", e.capabilities, url)
return nil
return false, nil
} else if response.StatusCode != http.StatusOK {
log.Printf("Received unexpected HTTP status from %s: %s", url, response.Status)
return e.errorIfMustRevalidate(ErrUnexpectedHttpStatus)
@ -164,19 +208,19 @@ func (e *capabilitiesEntry) update(response *http.Response, now time.Time) error
if !found || len(capaObj) == 0 {
log.Printf("No capabilities received for app spreed from %s: %+v", url, capaResponse)
e.capabilities = nil
return nil
return false, nil
}
var capa map[string]interface{}
if err := json.Unmarshal(capaObj, &capa); err != nil {
log.Printf("Unsupported capabilities received for app spreed from %s: %+v", url, capaResponse)
e.capabilities = nil
return nil
return false, nil
}
log.Printf("Received capabilities %+v from %s", capa, url)
e.capabilities = capa
return nil
return true, nil
}
func (e *capabilitiesEntry) GetCapabilities() map[string]interface{} {
@ -231,11 +275,15 @@ func (c *Capabilities) getCapabilities(key string) (*capabilitiesEntry, bool) {
now := c.getNow()
entry, found := c.entries[key]
if found && entry.valid(now) {
return entry, true
if !found {
// Upgrade to write-lock
c.mu.RUnlock()
defer c.mu.RLock()
entry = c.newCapabilitiesEntry(key)
}
return entry, false
return entry, entry.valid(now)
}
func (c *Capabilities) invalidateCapabilities(key string) {
@ -260,7 +308,7 @@ func (c *Capabilities) newCapabilitiesEntry(key string) *capabilitiesEntry {
entry, found := c.entries[key]
if !found {
entry = newCapabilitiesEntry()
entry = newCapabilitiesEntry(c)
c.entries[key] = entry
}
@ -279,52 +327,12 @@ func (c *Capabilities) loadCapabilities(ctx context.Context, u *url.URL) (map[st
return entry.GetCapabilities(), true, nil
}
capUrl := *u
if !strings.Contains(capUrl.Path, "ocs/v2.php") {
if !strings.HasSuffix(capUrl.Path, "/") {
capUrl.Path += "/"
}
capUrl.Path = capUrl.Path + "ocs/v2.php/cloud/capabilities"
} else if pos := strings.Index(capUrl.Path, "/ocs/v2.php/"); pos >= 0 {
capUrl.Path = capUrl.Path[:pos+11] + "/cloud/capabilities"
}
log.Printf("Capabilities expired for %s, updating", capUrl.String())
client, pool, err := c.pool.Get(ctx, &capUrl)
if err != nil {
log.Printf("Could not get client for host %s: %s", capUrl.Host, err)
return nil, false, err
}
defer pool.Put(client)
req, err := http.NewRequestWithContext(ctx, "GET", capUrl.String(), nil)
if err != nil {
log.Printf("Could not create request to %s: %s", &capUrl, err)
return nil, false, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("OCS-APIRequest", "true")
req.Header.Set("User-Agent", "nextcloud-spreed-signaling/"+c.version)
if entry != nil {
entry.updateRequest(req)
}
resp, err := client.Do(req)
updated, err := entry.update(ctx, u, c.getNow())
if err != nil {
return nil, false, err
}
defer resp.Body.Close()
if entry == nil {
entry = c.newCapabilitiesEntry(key)
}
if err := entry.update(resp, c.getNow()); err != nil {
return nil, false, err
}
return entry.GetCapabilities(), false, nil
return entry.GetCapabilities(), !updated, nil
}
func (c *Capabilities) HasCapabilityFeature(ctx context.Context, u *url.URL, feature string) bool {

View file

@ -32,6 +32,7 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@ -404,7 +405,7 @@ func TestCapabilitiesNoCacheETag(t *testing.T) {
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
assert.True(cached)
}
value = called.Load()
@ -444,7 +445,7 @@ func TestCapabilitiesCacheNoMustRevalidate(t *testing.T) {
// "must-revalidate" is not set.
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
assert.True(cached)
}
value = called.Load()
@ -484,7 +485,7 @@ func TestCapabilitiesNoCacheNoMustRevalidate(t *testing.T) {
// "must-revalidate" is not set.
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
assert.True(cached)
}
value = called.Load()
@ -528,3 +529,56 @@ func TestCapabilitiesNoCacheMustRevalidate(t *testing.T) {
value = called.Load()
assert.EqualValues(2, value)
}
func TestConcurrentExpired(t *testing.T) {
t.Parallel()
CatchLogForTest(t)
assert := assert.New(t)
var called atomic.Uint32
url, capabilities := NewCapabilitiesForTestWithCallback(t, func(cr *CapabilitiesResponse, w http.ResponseWriter) error {
called.Add(1)
return nil
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
expectedString := "bar"
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
assert.False(cached)
}
count := 100
start := make(chan struct{})
var numCached atomic.Uint32
var numFetched atomic.Uint32
var finished sync.WaitGroup
for i := 0; i < count; i++ {
finished.Add(1)
go func() {
defer finished.Done()
<-start
if value, cached, found := capabilities.GetStringConfig(ctx, url, "signaling", "foo"); assert.True(found) {
assert.Equal(expectedString, value)
if cached {
numCached.Add(1)
} else {
numFetched.Add(1)
}
}
}()
}
SetCapabilitiesGetNow(t, capabilities, func() time.Time {
return time.Now().Add(minCapabilitiesCacheDuration)
})
close(start)
finished.Wait()
assert.EqualValues(2, called.Load())
assert.EqualValues(count, numFetched.Load()+numCached.Load())
assert.EqualValues(1, numFetched.Load())
assert.EqualValues(count-1, numCached.Load())
}