Make LruCache typed through generics.

This commit is contained in:
Joachim Bauch 2025-09-30 09:55:57 +02:00
commit 178503fef7
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
4 changed files with 43 additions and 52 deletions

12
hub.go
View file

@ -171,7 +171,7 @@ type Hub struct {
roomPing *RoomPing
virtualSessions map[PublicSessionId]uint64
decodeCaches []*LruCache
decodeCaches []*LruCache[*SessionIdData]
mcu Mcu
mcuTimeout time.Duration
@ -285,9 +285,9 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer
log.Printf("No trusted proxies configured, only allowing for %s", trustedProxiesIps)
}
decodeCaches := make([]*LruCache, 0, numDecodeCaches)
decodeCaches := make([]*LruCache[*SessionIdData], 0, numDecodeCaches)
for range numDecodeCaches {
decodeCaches = append(decodeCaches, NewLruCache(decodeCacheSize))
decodeCaches = append(decodeCaches, NewLruCache[*SessionIdData](decodeCacheSize))
}
roomSessions, err := NewBuiltinRoomSessions(rpcClients)
@ -597,7 +597,7 @@ func (h *Hub) Reload(config *goconf.ConfigFile) {
h.rpcClients.Reload(config)
}
func (h *Hub) getDecodeCache(cache_key string) *LruCache {
func (h *Hub) getDecodeCache(cache_key string) *LruCache[*SessionIdData] {
hash := fnv.New32a()
hash.Write([]byte(cache_key)) // nolint
idx := hash.Sum32() % uint32(len(h.decodeCaches))
@ -648,7 +648,7 @@ func (h *Hub) decodePrivateSessionId(id PrivateSessionId) *SessionIdData {
cache_key := fmt.Sprintf("%s|%s", id, privateSessionName)
cache := h.getDecodeCache(cache_key)
if result := cache.Get(cache_key); result != nil {
return result.(*SessionIdData)
return result
}
data, err := h.cookie.DecodePrivate(id)
@ -668,7 +668,7 @@ func (h *Hub) decodePublicSessionId(id PublicSessionId) *SessionIdData {
cache_key := fmt.Sprintf("%s|%s", id, publicSessionName)
cache := h.getDecodeCache(cache_key)
if result := cache.Get(cache_key); result != nil {
return result.(*SessionIdData)
return result
}
data, err := h.cookie.DecodePublic(id)

35
lru.go
View file

@ -26,36 +26,36 @@ import (
"sync"
)
type cacheEntry struct {
type cacheEntry[T any] struct {
key string
value any
value T
}
type LruCache struct {
type LruCache[T any] struct {
size int
mu sync.Mutex
entries *list.List
data map[string]*list.Element
}
func NewLruCache(size int) *LruCache {
return &LruCache{
func NewLruCache[T any](size int) *LruCache[T] {
return &LruCache[T]{
size: size,
entries: list.New(),
data: make(map[string]*list.Element),
}
}
func (c *LruCache) Set(key string, value any) {
func (c *LruCache[T]) Set(key string, value T) {
c.mu.Lock()
if v, found := c.data[key]; found {
c.entries.MoveToFront(v)
v.Value.(*cacheEntry).value = value
v.Value.(*cacheEntry[T]).value = value
c.mu.Unlock()
return
}
v := c.entries.PushFront(&cacheEntry{
v := c.entries.PushFront(&cacheEntry[T]{
key: key,
value: value,
})
@ -66,20 +66,21 @@ func (c *LruCache) Set(key string, value any) {
c.mu.Unlock()
}
func (c *LruCache) Get(key string) any {
func (c *LruCache[T]) Get(key string) T {
c.mu.Lock()
if v, found := c.data[key]; found {
c.entries.MoveToFront(v)
value := v.Value.(*cacheEntry).value
value := v.Value.(*cacheEntry[T]).value
c.mu.Unlock()
return value
}
c.mu.Unlock()
return nil
var defaultValue T
return defaultValue
}
func (c *LruCache) Remove(key string) {
func (c *LruCache[T]) Remove(key string) {
c.mu.Lock()
if v, found := c.data[key]; found {
c.removeElement(v)
@ -87,26 +88,26 @@ func (c *LruCache) Remove(key string) {
c.mu.Unlock()
}
func (c *LruCache) removeOldestLocked() {
func (c *LruCache[T]) removeOldestLocked() {
v := c.entries.Back()
if v != nil {
c.removeElement(v)
}
}
func (c *LruCache) RemoveOldest() {
func (c *LruCache[T]) RemoveOldest() {
c.mu.Lock()
c.removeOldestLocked()
c.mu.Unlock()
}
func (c *LruCache) removeElement(e *list.Element) {
func (c *LruCache[T]) removeElement(e *list.Element) {
c.entries.Remove(e)
entry := e.Value.(*cacheEntry)
entry := e.Value.(*cacheEntry[T])
delete(c.data, entry.key)
}
func (c *LruCache) Len() int {
func (c *LruCache[T]) Len() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.entries.Len()

View file

@ -30,7 +30,7 @@ import (
func TestLruUnbound(t *testing.T) {
assert := assert.New(t)
lru := NewLruCache(0)
lru := NewLruCache[int](0)
count := 10
for i := range count {
key := fmt.Sprintf("%d", i)
@ -39,9 +39,8 @@ func TestLruUnbound(t *testing.T) {
assert.Equal(count, lru.Len())
for i := range count {
key := fmt.Sprintf("%d", i)
if value := lru.Get(key); assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
}
value := lru.Get(key)
assert.EqualValues(i, value, "Failed for %s", key)
}
// The first key ("0") is now the oldest.
lru.RemoveOldest()
@ -49,12 +48,7 @@ func TestLruUnbound(t *testing.T) {
for i := range count {
key := fmt.Sprintf("%d", i)
value := lru.Get(key)
if i == 0 {
assert.Nil(value, "The value for key %s should have been removed", key)
continue
} else if assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
}
assert.EqualValues(i, value, "Failed for %s", key)
}
// NOTE: Key "0" no longer exists below, so make sure to not set it again.
@ -68,9 +62,8 @@ func TestLruUnbound(t *testing.T) {
// NOTE: The same ordering as the Set calls above.
for i := count - 1; i >= 1; i-- {
key := fmt.Sprintf("%d", i)
if value := lru.Get(key); assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
}
value := lru.Get(key)
assert.EqualValues(i, value, "Failed for %s", key)
}
// The last key ("9") is now the oldest.
@ -80,10 +73,9 @@ func TestLruUnbound(t *testing.T) {
key := fmt.Sprintf("%d", i)
value := lru.Get(key)
if i == 0 || i == count-1 {
assert.Nil(value, "The value for key %s should have been removed", key)
continue
} else if assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
assert.EqualValues(0, value, "The value for key %s should have been removed", key)
} else {
assert.EqualValues(i, value, "Failed for %s", key)
}
}
@ -95,10 +87,9 @@ func TestLruUnbound(t *testing.T) {
key := fmt.Sprintf("%d", i)
value := lru.Get(key)
if i == 0 || i == count-1 || i == count/2 {
assert.Nil(value, "The value for key %s should have been removed", key)
continue
} else if assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
assert.EqualValues(0, value, "The value for key %s should have been removed", key)
} else {
assert.EqualValues(i, value, "Failed for %s", key)
}
}
}
@ -106,7 +97,7 @@ func TestLruUnbound(t *testing.T) {
func TestLruBound(t *testing.T) {
assert := assert.New(t)
size := 2
lru := NewLruCache(size)
lru := NewLruCache[int](size)
count := 10
for i := range count {
key := fmt.Sprintf("%d", i)
@ -118,10 +109,9 @@ func TestLruBound(t *testing.T) {
key := fmt.Sprintf("%d", i)
value := lru.Get(key)
if i < count-size {
assert.Nil(value, "The value for key %s should have been removed", key)
continue
} else if assert.NotNil(value, "No value found for %s", key) {
assert.EqualValues(i, value)
assert.EqualValues(0, value, "The value for key %s should have been removed", key)
} else {
assert.EqualValues(i, value, "Failed for %s", key)
}
}
}

View file

@ -49,7 +49,7 @@ type tokensEtcd struct {
client *signaling.EtcdClient
tokenFormats atomic.Value
tokenCache *signaling.LruCache
tokenCache *signaling.LruCache[*tokenCacheEntry]
}
func NewProxyTokensEtcd(config *goconf.ConfigFile) (ProxyTokens, error) {
@ -64,7 +64,7 @@ func NewProxyTokensEtcd(config *goconf.ConfigFile) (ProxyTokens, error) {
result := &tokensEtcd{
client: client,
tokenCache: signaling.NewLruCache(tokenCacheSize),
tokenCache: signaling.NewLruCache[*tokenCacheEntry](tokenCacheSize),
}
if err := result.load(config, false); err != nil {
return nil, err
@ -98,7 +98,7 @@ func (t *tokensEtcd) getByKey(id string, key string) (*ProxyToken, error) {
}
keyValue := resp.Kvs[len(resp.Kvs)-1].Value
cached, _ := t.tokenCache.Get(key).(*tokenCacheEntry)
cached := t.tokenCache.Get(key)
if cached == nil || !bytes.Equal(cached.keyValue, keyValue) {
// Parsed public keys are cached to avoid the parse overhead.
publicKey, err := jwt.ParseRSAPublicKeyFromPEM(keyValue)