crypto/decryptolm: store olm hashes to prevent errors if they're repeated

This commit is contained in:
Tulir Asokan 2024-12-20 14:38:24 +02:00
commit e844153658
6 changed files with 148 additions and 4 deletions

View file

@ -8,6 +8,8 @@ package crypto
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@ -29,6 +31,7 @@ var (
SenderMismatch = errors.New("mismatched sender in olm payload")
RecipientMismatch = errors.New("mismatched recipient in olm payload")
RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
ErrDuplicateMessage = errors.New("duplicate olm message")
)
// DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm.
@ -113,14 +116,35 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
return &olmEvt, nil
}
func olmMessageHash(ciphertext string) ([32]byte, error) {
ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext)
return sha256.Sum256(ciphertextBytes), err
}
func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
ciphertextHash, err := olmMessageHash(ciphertext)
if err != nil {
return nil, fmt.Errorf("failed to hash olm ciphertext: %w", err)
}
log := *zerolog.Ctx(ctx)
endTimeTrace := mach.timeTrace(ctx, "waiting for olm lock", 5*time.Second)
mach.olmLock.Lock()
endTimeTrace()
defer mach.olmLock.Unlock()
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext)
duplicateTS, err := mach.CryptoStore.GetOlmHash(ctx, ciphertextHash)
if err != nil {
log.Warn().Err(err).Msg("Failed to check for duplicate olm message")
} else if !duplicateTS.IsZero() {
log.Warn().
Hex("ciphertext_hash", ciphertextHash[:]).
Time("duplicate_ts", duplicateTS).
Msg("Ignoring duplicate olm message")
return nil, ErrDuplicateMessage
}
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash)
if err != nil {
if err == DecryptionFailedWithMatchingSession {
log.Warn().Msg("Found matching session, but decryption failed")
@ -153,6 +177,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
}
log = log.With().Str("new_olm_session_id", session.ID().String()).Logger()
log.Debug().
Hex("ciphertext_hash", ciphertextHash[:]).
Str("olm_session_description", session.Describe()).
Msg("Created inbound olm session")
ctx = log.WithContext(ctx)
@ -166,6 +191,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
}
endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second)
err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now())
if err != nil {
log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting")
}
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
endTimeTrace()
if err != nil {
@ -176,7 +205,9 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
const MaxOlmSessionsPerDevice = 5
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, ciphertextHash [32]byte,
) ([]byte, error) {
log := *zerolog.Ctx(ctx)
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey)
@ -229,6 +260,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
endTimeTrace()
if err != nil {
log.Warn().Err(err).
Hex("ciphertext_hash", ciphertextHash[:]).
Str("session_description", session.Describe()).
Msg("Failed to decrypt olm message")
if olmType == id.OlmMsgTypePreKey {
@ -236,12 +268,19 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
}
} else {
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now())
if err != nil {
log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting")
}
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
endTimeTrace()
if err != nil {
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
}
log.Debug().Str("session_description", session.Describe()).Msg("Decrypted olm message")
log.Debug().
Hex("ciphertext_hash", ciphertextHash[:]).
Str("session_description", session.Describe()).
Msg("Decrypted olm message")
return plaintext, nil
}
}

View file

@ -63,6 +63,9 @@ type OlmMachine struct {
devicesToUnwedgeLock sync.Mutex
recentlyUnwedged map[id.IdentityKey]time.Time
recentlyUnwedgedLock sync.Mutex
olmHashSavePoints []time.Time
lastHashDelete time.Time
olmHashSavePointLock sync.Mutex
olmLock sync.Mutex
megolmEncryptLock sync.Mutex
@ -312,6 +315,7 @@ func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.R
}
mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
mach.MarkOlmHashSavePoint(ctx)
return true
}
@ -399,6 +403,35 @@ func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Eve
}
}
const olmHashSavePointCount = 5
const olmHashDeleteMinInterval = 10 * time.Minute
const minSavePointInterval = 1 * time.Minute
// MarkOlmHashSavePoint marks the current time as a save point for olm hashes and deletes old hashes if needed.
//
// This should be called after all to-device events in a sync have been processed.
// The function will then delete old olm hashes after enough syncs have happened
// (such that it's unlikely for the olm messages to repeat).
func (mach *OlmMachine) MarkOlmHashSavePoint(ctx context.Context) {
mach.olmHashSavePointLock.Lock()
defer mach.olmHashSavePointLock.Unlock()
if len(mach.olmHashSavePoints) > 0 && time.Since(mach.olmHashSavePoints[len(mach.olmHashSavePoints)-1]) < minSavePointInterval {
return
}
mach.olmHashSavePoints = append(mach.olmHashSavePoints, time.Now())
if len(mach.olmHashSavePoints) > olmHashSavePointCount {
sp := mach.olmHashSavePoints[0]
mach.olmHashSavePoints = mach.olmHashSavePoints[1:]
if time.Since(mach.lastHashDelete) > olmHashDeleteMinInterval {
err := mach.CryptoStore.DeleteOldOlmHashes(ctx, sp)
mach.lastHashDelete = time.Now()
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete old olm hashes")
}
}
}
}
// HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you
// don't need to add any custom handlers if you use that method.
func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) {

View file

@ -279,6 +279,29 @@ func (store *SQLCryptoStore) DeleteSession(ctx context.Context, _ id.SenderKey,
return err
}
func (store *SQLCryptoStore) PutOlmHash(ctx context.Context, messageHash [32]byte, receivedAt time.Time) error {
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_message_hash (account_id, received_at, message_hash) VALUES ($1, $2, $3) ON CONFLICT (message_hash) DO NOTHING", store.Account, messageHash[:], receivedAt.UnixMilli())
return err
}
func (store *SQLCryptoStore) GetOlmHash(ctx context.Context, messageHash [32]byte) (receivedAt time.Time, err error) {
var receivedAtInt int64
err = store.DB.QueryRow(ctx, "SELECT received_at FROM crypto_olm_message_hash WHERE message_hash=$1", messageHash).Scan(&receivedAtInt)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
receivedAt = time.UnixMilli(receivedAtInt)
return
}
func (store *SQLCryptoStore) DeleteOldOlmHashes(ctx context.Context, beforeTS time.Time) error {
_, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_message_hash WHERE account_id = $1 AND received_at < $2", store.AccountID, beforeTS.UnixMilli())
return err
}
func datePtr(t time.Time) *time.Time {
if t.IsZero() {
return nil

View file

@ -1,4 +1,4 @@
-- v0 -> v16 (compatible with v15+): Latest revision
-- v0 -> v17 (compatible with v15+): Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
@ -45,6 +45,16 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session (
);
CREATE INDEX crypto_olm_session_sender_key_idx ON crypto_olm_session (account_id, sender_key);
CREATE TABLE crypto_olm_message_hash (
account_id TEXT NOT NULL,
received_at BIGINT NOT NULL,
message_hash bytea NOT NULL PRIMARY KEY,
CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id)
REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id);
CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
account_id TEXT,
session_id CHAR(43),

View file

@ -0,0 +1,11 @@
-- v17 (compatible with v15+): Add table for decrypted Olm message hashes
CREATE TABLE crypto_olm_message_hash (
account_id TEXT NOT NULL,
received_at BIGINT NOT NULL,
message_hash bytea NOT NULL PRIMARY KEY,
CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id)
REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id);

View file

@ -12,8 +12,10 @@ import (
"slices"
"sort"
"sync"
"time"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exsync"
"golang.org/x/exp/maps"
"maunium.net/go/mautrix/event"
@ -51,6 +53,13 @@ type Store interface {
// DeleteSession deletes the given session that has been previously inserted with AddSession.
DeleteSession(context.Context, id.SenderKey, *OlmSession) error
// PutOlmHash marks a given olm message hash as handled.
PutOlmHash(context.Context, [32]byte, time.Time) error
// GetOlmHash gets the time that a given olm hash was handled.
GetOlmHash(context.Context, [32]byte) (time.Time, error)
// DeleteOldOlmHashes deletes all olm hashes that were handled before the given time.
DeleteOldOlmHashes(context.Context, time.Time) error
// PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted
// with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace
// sessions inserted with this call.
@ -176,6 +185,7 @@ type MemoryStore struct {
KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string
OutdatedUsers map[id.UserID]struct{}
Secrets map[id.Secret]string
OlmHashes *exsync.Set[[32]byte]
}
var _ Store = (*MemoryStore)(nil)
@ -198,6 +208,7 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore {
KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string),
OutdatedUsers: make(map[id.UserID]struct{}),
Secrets: make(map[id.Secret]string),
OlmHashes: exsync.NewSet[[32]byte](),
}
}
@ -263,6 +274,23 @@ func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) boo
return ok && len(sessions) > 0 && !sessions[0].Expired()
}
func (gs *MemoryStore) PutOlmHash(_ context.Context, hash [32]byte, receivedAt time.Time) error {
gs.OlmHashes.Add(hash)
return nil
}
func (gs *MemoryStore) GetOlmHash(_ context.Context, hash [32]byte) (time.Time, error) {
if gs.OlmHashes.Has(hash) {
// The time isn't that important, so we just return the current time
return time.Now(), nil
}
return time.Time{}, nil
}
func (gs *MemoryStore) DeleteOldOlmHashes(_ context.Context, beforeTS time.Time) error {
return nil
}
func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()