mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
crypto: delete old olm sessions if there are too many (#315)
This commit is contained in:
parent
c8e197a4f9
commit
9373794606
3 changed files with 59 additions and 3 deletions
|
|
@ -11,6 +11,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
|
@ -74,6 +75,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
|
|||
return nil, UnsupportedOlmMessageType
|
||||
}
|
||||
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
Stringer("sender_key", senderKey).
|
||||
Int("olm_msg_type", int(olmType)).
|
||||
Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second)
|
||||
plaintext, err := mach.tryDecryptOlmCiphertext(ctx, evt.Sender, senderKey, olmType, ciphertext)
|
||||
endTimeTrace()
|
||||
|
|
@ -168,6 +174,8 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
return plaintext, nil
|
||||
}
|
||||
|
||||
const MaxOlmSessionsPerDevice = 5
|
||||
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
|
||||
log := *zerolog.Ctx(ctx)
|
||||
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
|
||||
|
|
@ -176,6 +184,31 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err)
|
||||
}
|
||||
if len(sessions) > MaxOlmSessionsPerDevice*2 {
|
||||
// SQL store sorts sessions, but other implementations may not, so re-sort just in case
|
||||
slices.SortFunc(sessions, func(a, b *OlmSession) int {
|
||||
return b.LastDecryptedTime.Compare(a.LastDecryptedTime)
|
||||
})
|
||||
log.Warn().
|
||||
Int("session_count", len(sessions)).
|
||||
Time("newest_last_decrypted_at", sessions[0].LastDecryptedTime).
|
||||
Time("oldest_last_decrypted_at", sessions[len(sessions)-1].LastDecryptedTime).
|
||||
Msg("Too many sessions, deleting old ones")
|
||||
for i := MaxOlmSessionsPerDevice; i < len(sessions); i++ {
|
||||
err = mach.CryptoStore.DeleteSession(ctx, senderKey, sessions[i])
|
||||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Stringer("olm_session_id", sessions[i].ID()).
|
||||
Time("last_decrypt", sessions[i].LastDecryptedTime).
|
||||
Msg("Failed to delete olm session")
|
||||
} else {
|
||||
log.Debug().
|
||||
Stringer("olm_session_id", sessions[i].ID()).
|
||||
Time("last_decrypt", sessions[i].LastDecryptedTime).
|
||||
Msg("Deleted olm session")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
log := log.With().Str("olm_session_id", session.ID().String()).Logger()
|
||||
|
|
@ -190,11 +223,13 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
|
|||
continue
|
||||
}
|
||||
}
|
||||
log.Debug().Str("session_description", session.Describe()).Msg("Trying to decrypt olm message")
|
||||
endTimeTrace = mach.timeTrace(ctx, "decrypting olm message", time.Second)
|
||||
plaintext, err := session.Decrypt(ciphertext, olmType)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Str("session_description", session.Describe()).
|
||||
Msg("Failed to decrypt olm message")
|
||||
if olmType == id.OlmMsgTypePreKey {
|
||||
return nil, DecryptionFailedWithMatchingSession
|
||||
}
|
||||
|
|
@ -205,7 +240,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
|
|||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
|
||||
}
|
||||
log.Debug().Msg("Decrypted olm message")
|
||||
log.Debug().Str("session_description", session.Describe()).Msg("Decrypted olm message")
|
||||
return plaintext, nil
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -219,7 +219,7 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session
|
|||
return data
|
||||
}
|
||||
|
||||
// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID.
|
||||
// GetLatestSession retrieves the Olm session for a given sender key from the database that had the most recent successful decryption.
|
||||
func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) {
|
||||
store.olmSessionCacheLock.Lock()
|
||||
defer store.olmSessionCacheLock.Unlock()
|
||||
|
|
@ -274,6 +274,11 @@ func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey,
|
|||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) DeleteSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error {
|
||||
_, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_session WHERE session_id=$1 AND account_id=$2", session.ID(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
func datePtr(t time.Time) *time.Time {
|
||||
if t.IsZero() {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ package crypto
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
|
|
@ -47,6 +48,8 @@ type Store interface {
|
|||
GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error)
|
||||
// UpdateSession updates a session that has previously been inserted with AddSession.
|
||||
UpdateSession(context.Context, id.SenderKey, *OlmSession) error
|
||||
// DeleteSession deletes the given session that has been previously inserted with AddSession.
|
||||
DeleteSession(context.Context, id.SenderKey, *OlmSession) error
|
||||
|
||||
// PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted
|
||||
// with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace
|
||||
|
|
@ -233,6 +236,19 @@ func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, ses
|
|||
return gs.save()
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) DeleteSession(ctx context.Context, senderKey id.SenderKey, target *OlmSession) error {
|
||||
gs.lock.Lock()
|
||||
defer gs.lock.Unlock()
|
||||
sessions, ok := gs.Sessions[senderKey]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
gs.Sessions[senderKey] = slices.DeleteFunc(sessions, func(session *OlmSession) bool {
|
||||
return session == target
|
||||
})
|
||||
return gs.save()
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error {
|
||||
// we don't need to do anything here because the session is a pointer and already stored in our map
|
||||
return gs.save()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue