diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 126d54de..d809d039 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -103,7 +103,7 @@ func (prov *ProvisioningAPI) Init() { prov.logins = make(map[string]*ProvLogin) prov.net = prov.br.Bridge.Network prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() - prov.fedClient = federation.NewClient("", nil) + prov.fedClient = federation.NewClient("", nil, nil) prov.fedClient.HTTP.Timeout = 20 * time.Second tp := prov.fedClient.HTTP.Transport.(*federation.ServerResolvingTransport) tp.Dialer.Timeout = 10 * time.Second diff --git a/federation/cache.go b/federation/cache.go new file mode 100644 index 00000000..95d096fa --- /dev/null +++ b/federation/cache.go @@ -0,0 +1,71 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "sync" + "time" +) + +// ResolutionCache is an interface for caching resolved server names. +type ResolutionCache interface { + StoreResolution(*ResolvedServerName) + // LoadResolution loads a resolved server name from the cache. + // Expired entries MUST NOT be returned. + LoadResolution(serverName string) (*ResolvedServerName, error) +} + +type KeyCache interface { + StoreKeys(*ServerKeyResponse) + LoadKeys(serverName string) (*ServerKeyResponse, error) +} + +type InMemoryCache struct { + resolutions map[string]*ResolvedServerName + resolutionsLock sync.RWMutex + keys map[string]*ServerKeyResponse + keysLock sync.RWMutex +} + +func NewInMemoryCache() *InMemoryCache { + return &InMemoryCache{ + resolutions: make(map[string]*ResolvedServerName), + keys: make(map[string]*ServerKeyResponse), + } +} + +func (c *InMemoryCache) StoreResolution(resolution *ResolvedServerName) { + c.resolutionsLock.Lock() + defer c.resolutionsLock.Unlock() + c.resolutions[resolution.ServerName] = resolution +} + +func (c *InMemoryCache) LoadResolution(serverName string) (*ResolvedServerName, error) { + c.resolutionsLock.RLock() + defer c.resolutionsLock.RUnlock() + resolution, ok := c.resolutions[serverName] + if !ok || time.Until(resolution.Expires) < 0 { + return nil, nil + } + return resolution, nil +} + +func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) { + c.keysLock.Lock() + defer c.keysLock.Unlock() + c.keys[keys.ServerName] = keys +} + +func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) { + c.keysLock.RLock() + defer c.keysLock.RUnlock() + keys, ok := c.keys[serverName] + if !ok || time.Until(keys.ValidUntilTS.Time) < 0 { + return nil, nil + } + return keys, nil +} diff --git a/federation/client.go b/federation/client.go index 7fc630b7..7aff19c9 100644 --- a/federation/client.go +++ b/federation/client.go @@ -32,10 +32,10 @@ type Client struct { Key *SigningKey } -func NewClient(serverName string, key *SigningKey) *Client { +func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client { return &Client{ HTTP: &http.Client{ - Transport: NewServerResolvingTransport(), + Transport: NewServerResolvingTransport(cache), Timeout: 120 * time.Second, }, UserAgent: mautrix.DefaultUserAgent, diff --git a/federation/client_test.go b/federation/client_test.go index ba3c3ed4..ece399ea 100644 --- a/federation/client_test.go +++ b/federation/client_test.go @@ -16,7 +16,7 @@ import ( ) func TestClient_Version(t *testing.T) { - cli := federation.NewClient("", nil) + cli := federation.NewClient("", nil, nil) resp, err := cli.Version(context.TODO(), "maunium.net") require.NoError(t, err) require.Equal(t, "Synapse", resp.Server.Name) diff --git a/federation/httpclient.go b/federation/httpclient.go index d6d97280..cbb1674d 100644 --- a/federation/httpclient.go +++ b/federation/httpclient.go @@ -12,7 +12,6 @@ import ( "net" "net/http" "sync" - "time" ) // ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests. @@ -22,17 +21,20 @@ type ServerResolvingTransport struct { Transport *http.Transport Dialer *net.Dialer - cache map[string]*ResolvedServerName - resolveLocks map[string]*sync.Mutex - cacheLock sync.Mutex + cache ResolutionCache + + resolveLocks map[string]*sync.Mutex + resolveLocksLock sync.Mutex } -func NewServerResolvingTransport() *ServerResolvingTransport { +func NewServerResolvingTransport(cache ResolutionCache) *ServerResolvingTransport { + if cache == nil { + cache = NewInMemoryCache() + } srt := &ServerResolvingTransport{ - cache: make(map[string]*ResolvedServerName), resolveLocks: make(map[string]*sync.Mutex), - - Dialer: &net.Dialer{}, + cache: cache, + Dialer: &net.Dialer{}, } srt.Transport = &http.Transport{ DialContext: srt.DialContext, @@ -72,37 +74,25 @@ func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Res } func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) { - res, lock := srt.getResolveCache(serverName) - if res != nil { - return res, nil + srt.resolveLocksLock.Lock() + lock, ok := srt.resolveLocks[serverName] + if !ok { + lock = &sync.Mutex{} + srt.resolveLocks[serverName] = lock } + srt.resolveLocksLock.Unlock() + lock.Lock() defer lock.Unlock() - res, _ = srt.getResolveCache(serverName) - if res != nil { + res, err := srt.cache.LoadResolution(serverName) + if err != nil { + return nil, fmt.Errorf("failed to read cache: %w", err) + } else if res != nil { + return res, nil + } else if res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts); err != nil { + return nil, err + } else { + srt.cache.StoreResolution(res) return res, nil } - var err error - res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts) - if err != nil { - return nil, err - } - srt.cacheLock.Lock() - srt.cache[serverName] = res - srt.cacheLock.Unlock() - return res, nil -} - -func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) { - srt.cacheLock.Lock() - defer srt.cacheLock.Unlock() - if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 { - return val, nil - } - rl, ok := srt.resolveLocks[serverName] - if !ok { - rl = &sync.Mutex{} - srt.resolveLocks[serverName] = rl - } - return nil, rl }