federation: move server name cache to separate type

This commit is contained in:
Tulir Asokan 2025-05-03 01:43:56 +03:00
commit 36781e7de4
5 changed files with 100 additions and 39 deletions

View file

@ -103,7 +103,7 @@ func (prov *ProvisioningAPI) Init() {
prov.logins = make(map[string]*ProvLogin)
prov.net = prov.br.Bridge.Network
prov.log = prov.br.Log.With().Str("component", "provisioning").Logger()
prov.fedClient = federation.NewClient("", nil)
prov.fedClient = federation.NewClient("", nil, nil)
prov.fedClient.HTTP.Timeout = 20 * time.Second
tp := prov.fedClient.HTTP.Transport.(*federation.ServerResolvingTransport)
tp.Dialer.Timeout = 10 * time.Second

71
federation/cache.go Normal file
View file

@ -0,0 +1,71 @@
// Copyright (c) 2025 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation
import (
"sync"
"time"
)
// ResolutionCache is an interface for caching resolved server names.
type ResolutionCache interface {
StoreResolution(*ResolvedServerName)
// LoadResolution loads a resolved server name from the cache.
// Expired entries MUST NOT be returned.
LoadResolution(serverName string) (*ResolvedServerName, error)
}
type KeyCache interface {
StoreKeys(*ServerKeyResponse)
LoadKeys(serverName string) (*ServerKeyResponse, error)
}
type InMemoryCache struct {
resolutions map[string]*ResolvedServerName
resolutionsLock sync.RWMutex
keys map[string]*ServerKeyResponse
keysLock sync.RWMutex
}
func NewInMemoryCache() *InMemoryCache {
return &InMemoryCache{
resolutions: make(map[string]*ResolvedServerName),
keys: make(map[string]*ServerKeyResponse),
}
}
func (c *InMemoryCache) StoreResolution(resolution *ResolvedServerName) {
c.resolutionsLock.Lock()
defer c.resolutionsLock.Unlock()
c.resolutions[resolution.ServerName] = resolution
}
func (c *InMemoryCache) LoadResolution(serverName string) (*ResolvedServerName, error) {
c.resolutionsLock.RLock()
defer c.resolutionsLock.RUnlock()
resolution, ok := c.resolutions[serverName]
if !ok || time.Until(resolution.Expires) < 0 {
return nil, nil
}
return resolution, nil
}
func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) {
c.keysLock.Lock()
defer c.keysLock.Unlock()
c.keys[keys.ServerName] = keys
}
func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) {
c.keysLock.RLock()
defer c.keysLock.RUnlock()
keys, ok := c.keys[serverName]
if !ok || time.Until(keys.ValidUntilTS.Time) < 0 {
return nil, nil
}
return keys, nil
}

View file

@ -32,10 +32,10 @@ type Client struct {
Key *SigningKey
}
func NewClient(serverName string, key *SigningKey) *Client {
func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client {
return &Client{
HTTP: &http.Client{
Transport: NewServerResolvingTransport(),
Transport: NewServerResolvingTransport(cache),
Timeout: 120 * time.Second,
},
UserAgent: mautrix.DefaultUserAgent,

View file

@ -16,7 +16,7 @@ import (
)
func TestClient_Version(t *testing.T) {
cli := federation.NewClient("", nil)
cli := federation.NewClient("", nil, nil)
resp, err := cli.Version(context.TODO(), "maunium.net")
require.NoError(t, err)
require.Equal(t, "Synapse", resp.Server.Name)

View file

@ -12,7 +12,6 @@ import (
"net"
"net/http"
"sync"
"time"
)
// ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests.
@ -22,17 +21,20 @@ type ServerResolvingTransport struct {
Transport *http.Transport
Dialer *net.Dialer
cache map[string]*ResolvedServerName
resolveLocks map[string]*sync.Mutex
cacheLock sync.Mutex
cache ResolutionCache
resolveLocks map[string]*sync.Mutex
resolveLocksLock sync.Mutex
}
func NewServerResolvingTransport() *ServerResolvingTransport {
func NewServerResolvingTransport(cache ResolutionCache) *ServerResolvingTransport {
if cache == nil {
cache = NewInMemoryCache()
}
srt := &ServerResolvingTransport{
cache: make(map[string]*ResolvedServerName),
resolveLocks: make(map[string]*sync.Mutex),
Dialer: &net.Dialer{},
cache: cache,
Dialer: &net.Dialer{},
}
srt.Transport = &http.Transport{
DialContext: srt.DialContext,
@ -72,37 +74,25 @@ func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Res
}
func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) {
res, lock := srt.getResolveCache(serverName)
if res != nil {
return res, nil
srt.resolveLocksLock.Lock()
lock, ok := srt.resolveLocks[serverName]
if !ok {
lock = &sync.Mutex{}
srt.resolveLocks[serverName] = lock
}
srt.resolveLocksLock.Unlock()
lock.Lock()
defer lock.Unlock()
res, _ = srt.getResolveCache(serverName)
if res != nil {
res, err := srt.cache.LoadResolution(serverName)
if err != nil {
return nil, fmt.Errorf("failed to read cache: %w", err)
} else if res != nil {
return res, nil
} else if res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts); err != nil {
return nil, err
} else {
srt.cache.StoreResolution(res)
return res, nil
}
var err error
res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts)
if err != nil {
return nil, err
}
srt.cacheLock.Lock()
srt.cache[serverName] = res
srt.cacheLock.Unlock()
return res, nil
}
func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) {
srt.cacheLock.Lock()
defer srt.cacheLock.Unlock()
if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 {
return val, nil
}
rl, ok := srt.resolveLocks[serverName]
if !ok {
rl = &sync.Mutex{}
srt.resolveLocks[serverName] = rl
}
return nil, rl
}