mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
federation: move server name cache to separate type
This commit is contained in:
parent
441349efac
commit
36781e7de4
5 changed files with 100 additions and 39 deletions
|
|
@ -103,7 +103,7 @@ func (prov *ProvisioningAPI) Init() {
|
|||
prov.logins = make(map[string]*ProvLogin)
|
||||
prov.net = prov.br.Bridge.Network
|
||||
prov.log = prov.br.Log.With().Str("component", "provisioning").Logger()
|
||||
prov.fedClient = federation.NewClient("", nil)
|
||||
prov.fedClient = federation.NewClient("", nil, nil)
|
||||
prov.fedClient.HTTP.Timeout = 20 * time.Second
|
||||
tp := prov.fedClient.HTTP.Transport.(*federation.ServerResolvingTransport)
|
||||
tp.Dialer.Timeout = 10 * time.Second
|
||||
|
|
|
|||
71
federation/cache.go
Normal file
71
federation/cache.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
// Copyright (c) 2025 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package federation
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ResolutionCache is an interface for caching resolved server names.
|
||||
type ResolutionCache interface {
|
||||
StoreResolution(*ResolvedServerName)
|
||||
// LoadResolution loads a resolved server name from the cache.
|
||||
// Expired entries MUST NOT be returned.
|
||||
LoadResolution(serverName string) (*ResolvedServerName, error)
|
||||
}
|
||||
|
||||
type KeyCache interface {
|
||||
StoreKeys(*ServerKeyResponse)
|
||||
LoadKeys(serverName string) (*ServerKeyResponse, error)
|
||||
}
|
||||
|
||||
type InMemoryCache struct {
|
||||
resolutions map[string]*ResolvedServerName
|
||||
resolutionsLock sync.RWMutex
|
||||
keys map[string]*ServerKeyResponse
|
||||
keysLock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewInMemoryCache() *InMemoryCache {
|
||||
return &InMemoryCache{
|
||||
resolutions: make(map[string]*ResolvedServerName),
|
||||
keys: make(map[string]*ServerKeyResponse),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *InMemoryCache) StoreResolution(resolution *ResolvedServerName) {
|
||||
c.resolutionsLock.Lock()
|
||||
defer c.resolutionsLock.Unlock()
|
||||
c.resolutions[resolution.ServerName] = resolution
|
||||
}
|
||||
|
||||
func (c *InMemoryCache) LoadResolution(serverName string) (*ResolvedServerName, error) {
|
||||
c.resolutionsLock.RLock()
|
||||
defer c.resolutionsLock.RUnlock()
|
||||
resolution, ok := c.resolutions[serverName]
|
||||
if !ok || time.Until(resolution.Expires) < 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return resolution, nil
|
||||
}
|
||||
|
||||
func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) {
|
||||
c.keysLock.Lock()
|
||||
defer c.keysLock.Unlock()
|
||||
c.keys[keys.ServerName] = keys
|
||||
}
|
||||
|
||||
func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) {
|
||||
c.keysLock.RLock()
|
||||
defer c.keysLock.RUnlock()
|
||||
keys, ok := c.keys[serverName]
|
||||
if !ok || time.Until(keys.ValidUntilTS.Time) < 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
|
@ -32,10 +32,10 @@ type Client struct {
|
|||
Key *SigningKey
|
||||
}
|
||||
|
||||
func NewClient(serverName string, key *SigningKey) *Client {
|
||||
func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client {
|
||||
return &Client{
|
||||
HTTP: &http.Client{
|
||||
Transport: NewServerResolvingTransport(),
|
||||
Transport: NewServerResolvingTransport(cache),
|
||||
Timeout: 120 * time.Second,
|
||||
},
|
||||
UserAgent: mautrix.DefaultUserAgent,
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import (
|
|||
)
|
||||
|
||||
func TestClient_Version(t *testing.T) {
|
||||
cli := federation.NewClient("", nil)
|
||||
cli := federation.NewClient("", nil, nil)
|
||||
resp, err := cli.Version(context.TODO(), "maunium.net")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Synapse", resp.Server.Name)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests.
|
||||
|
|
@ -22,17 +21,20 @@ type ServerResolvingTransport struct {
|
|||
Transport *http.Transport
|
||||
Dialer *net.Dialer
|
||||
|
||||
cache map[string]*ResolvedServerName
|
||||
resolveLocks map[string]*sync.Mutex
|
||||
cacheLock sync.Mutex
|
||||
cache ResolutionCache
|
||||
|
||||
resolveLocks map[string]*sync.Mutex
|
||||
resolveLocksLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewServerResolvingTransport() *ServerResolvingTransport {
|
||||
func NewServerResolvingTransport(cache ResolutionCache) *ServerResolvingTransport {
|
||||
if cache == nil {
|
||||
cache = NewInMemoryCache()
|
||||
}
|
||||
srt := &ServerResolvingTransport{
|
||||
cache: make(map[string]*ResolvedServerName),
|
||||
resolveLocks: make(map[string]*sync.Mutex),
|
||||
|
||||
Dialer: &net.Dialer{},
|
||||
cache: cache,
|
||||
Dialer: &net.Dialer{},
|
||||
}
|
||||
srt.Transport = &http.Transport{
|
||||
DialContext: srt.DialContext,
|
||||
|
|
@ -72,37 +74,25 @@ func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Res
|
|||
}
|
||||
|
||||
func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) {
|
||||
res, lock := srt.getResolveCache(serverName)
|
||||
if res != nil {
|
||||
return res, nil
|
||||
srt.resolveLocksLock.Lock()
|
||||
lock, ok := srt.resolveLocks[serverName]
|
||||
if !ok {
|
||||
lock = &sync.Mutex{}
|
||||
srt.resolveLocks[serverName] = lock
|
||||
}
|
||||
srt.resolveLocksLock.Unlock()
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
res, _ = srt.getResolveCache(serverName)
|
||||
if res != nil {
|
||||
res, err := srt.cache.LoadResolution(serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read cache: %w", err)
|
||||
} else if res != nil {
|
||||
return res, nil
|
||||
} else if res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
srt.cache.StoreResolution(res)
|
||||
return res, nil
|
||||
}
|
||||
var err error
|
||||
res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
srt.cacheLock.Lock()
|
||||
srt.cache[serverName] = res
|
||||
srt.cacheLock.Unlock()
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) {
|
||||
srt.cacheLock.Lock()
|
||||
defer srt.cacheLock.Unlock()
|
||||
if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 {
|
||||
return val, nil
|
||||
}
|
||||
rl, ok := srt.resolveLocks[serverName]
|
||||
if !ok {
|
||||
rl = &sync.Mutex{}
|
||||
srt.resolveLocks[serverName] = rl
|
||||
}
|
||||
return nil, rl
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue