crypto: delete old olm sessions if there are too many (#315)
Some checks are pending
Go / Lint (latest) (push) Waiting to run
Go / Build (old, libolm) (push) Waiting to run
Go / Build (latest, libolm) (push) Waiting to run
Go / Build (old, goolm) (push) Waiting to run
Go / Build (latest, goolm) (push) Waiting to run

This commit is contained in:
Tulir Asokan 2024-11-20 14:03:21 +01:00 committed by GitHub
commit 9373794606
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 59 additions and 3 deletions

View file

@ -11,6 +11,7 @@ import (
"encoding/json"
"errors"
"fmt"
"slices"
"time"
"github.com/rs/zerolog"
@ -74,6 +75,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
return nil, UnsupportedOlmMessageType
}
log := mach.machOrContextLog(ctx).With().
Stringer("sender_key", senderKey).
Int("olm_msg_type", int(olmType)).
Logger()
ctx = log.WithContext(ctx)
endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second)
plaintext, err := mach.tryDecryptOlmCiphertext(ctx, evt.Sender, senderKey, olmType, ciphertext)
endTimeTrace()
@ -168,6 +174,8 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
return plaintext, nil
}
const MaxOlmSessionsPerDevice = 5
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
log := *zerolog.Ctx(ctx)
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
@ -176,6 +184,31 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
if err != nil {
return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err)
}
if len(sessions) > MaxOlmSessionsPerDevice*2 {
// SQL store sorts sessions, but other implementations may not, so re-sort just in case
slices.SortFunc(sessions, func(a, b *OlmSession) int {
return b.LastDecryptedTime.Compare(a.LastDecryptedTime)
})
log.Warn().
Int("session_count", len(sessions)).
Time("newest_last_decrypted_at", sessions[0].LastDecryptedTime).
Time("oldest_last_decrypted_at", sessions[len(sessions)-1].LastDecryptedTime).
Msg("Too many sessions, deleting old ones")
for i := MaxOlmSessionsPerDevice; i < len(sessions); i++ {
err = mach.CryptoStore.DeleteSession(ctx, senderKey, sessions[i])
if err != nil {
log.Warn().Err(err).
Stringer("olm_session_id", sessions[i].ID()).
Time("last_decrypt", sessions[i].LastDecryptedTime).
Msg("Failed to delete olm session")
} else {
log.Debug().
Stringer("olm_session_id", sessions[i].ID()).
Time("last_decrypt", sessions[i].LastDecryptedTime).
Msg("Deleted olm session")
}
}
}
for _, session := range sessions {
log := log.With().Str("olm_session_id", session.ID().String()).Logger()
@ -190,11 +223,13 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
continue
}
}
log.Debug().Str("session_description", session.Describe()).Msg("Trying to decrypt olm message")
endTimeTrace = mach.timeTrace(ctx, "decrypting olm message", time.Second)
plaintext, err := session.Decrypt(ciphertext, olmType)
endTimeTrace()
if err != nil {
log.Warn().Err(err).
Str("session_description", session.Describe()).
Msg("Failed to decrypt olm message")
if olmType == id.OlmMsgTypePreKey {
return nil, DecryptionFailedWithMatchingSession
}
@ -205,7 +240,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
if err != nil {
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
}
log.Debug().Msg("Decrypted olm message")
log.Debug().Str("session_description", session.Describe()).Msg("Decrypted olm message")
return plaintext, nil
}
}

View file

@ -219,7 +219,7 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session
return data
}
// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID.
// GetLatestSession retrieves the Olm session for a given sender key from the database that had the most recent successful decryption.
func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) {
store.olmSessionCacheLock.Lock()
defer store.olmSessionCacheLock.Unlock()
@ -274,6 +274,11 @@ func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey,
return err
}
func (store *SQLCryptoStore) DeleteSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error {
_, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_session WHERE session_id=$1 AND account_id=$2", session.ID(), store.AccountID)
return err
}
func datePtr(t time.Time) *time.Time {
if t.IsZero() {
return nil

View file

@ -9,6 +9,7 @@ package crypto
import (
"context"
"fmt"
"slices"
"sort"
"sync"
@ -47,6 +48,8 @@ type Store interface {
GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error)
// UpdateSession updates a session that has previously been inserted with AddSession.
UpdateSession(context.Context, id.SenderKey, *OlmSession) error
// DeleteSession deletes the given session that has been previously inserted with AddSession.
DeleteSession(context.Context, id.SenderKey, *OlmSession) error
// PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted
// with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace
@ -233,6 +236,19 @@ func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, ses
return gs.save()
}
func (gs *MemoryStore) DeleteSession(ctx context.Context, senderKey id.SenderKey, target *OlmSession) error {
gs.lock.Lock()
defer gs.lock.Unlock()
sessions, ok := gs.Sessions[senderKey]
if !ok {
return nil
}
gs.Sessions[senderKey] = slices.DeleteFunc(sessions, func(session *OlmSession) bool {
return session == target
})
return gs.save()
}
func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error {
// we don't need to do anything here because the session is a pointer and already stored in our map
return gs.save()