mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
crypto/decryptolm: store olm hashes to prevent errors if they're repeated
This commit is contained in:
parent
918ed4bf23
commit
e844153658
6 changed files with 148 additions and 4 deletions
|
|
@ -8,6 +8,8 @@ package crypto
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -29,6 +31,7 @@ var (
|
|||
SenderMismatch = errors.New("mismatched sender in olm payload")
|
||||
RecipientMismatch = errors.New("mismatched recipient in olm payload")
|
||||
RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload")
|
||||
ErrDuplicateMessage = errors.New("duplicate olm message")
|
||||
)
|
||||
|
||||
// DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm.
|
||||
|
|
@ -113,14 +116,35 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e
|
|||
return &olmEvt, nil
|
||||
}
|
||||
|
||||
func olmMessageHash(ciphertext string) ([32]byte, error) {
|
||||
ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext)
|
||||
return sha256.Sum256(ciphertextBytes), err
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
|
||||
ciphertextHash, err := olmMessageHash(ciphertext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash olm ciphertext: %w", err)
|
||||
}
|
||||
|
||||
log := *zerolog.Ctx(ctx)
|
||||
endTimeTrace := mach.timeTrace(ctx, "waiting for olm lock", 5*time.Second)
|
||||
mach.olmLock.Lock()
|
||||
endTimeTrace()
|
||||
defer mach.olmLock.Unlock()
|
||||
|
||||
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext)
|
||||
duplicateTS, err := mach.CryptoStore.GetOlmHash(ctx, ciphertextHash)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to check for duplicate olm message")
|
||||
} else if !duplicateTS.IsZero() {
|
||||
log.Warn().
|
||||
Hex("ciphertext_hash", ciphertextHash[:]).
|
||||
Time("duplicate_ts", duplicateTS).
|
||||
Msg("Ignoring duplicate olm message")
|
||||
return nil, ErrDuplicateMessage
|
||||
}
|
||||
|
||||
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash)
|
||||
if err != nil {
|
||||
if err == DecryptionFailedWithMatchingSession {
|
||||
log.Warn().Msg("Found matching session, but decryption failed")
|
||||
|
|
@ -153,6 +177,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
}
|
||||
log = log.With().Str("new_olm_session_id", session.ID().String()).Logger()
|
||||
log.Debug().
|
||||
Hex("ciphertext_hash", ciphertextHash[:]).
|
||||
Str("olm_session_description", session.Describe()).
|
||||
Msg("Created inbound olm session")
|
||||
ctx = log.WithContext(ctx)
|
||||
|
|
@ -166,6 +191,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
}
|
||||
|
||||
endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second)
|
||||
err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now())
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting")
|
||||
}
|
||||
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
|
|
@ -176,7 +205,9 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
|
||||
const MaxOlmSessionsPerDevice = 5
|
||||
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(
|
||||
ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, ciphertextHash [32]byte,
|
||||
) ([]byte, error) {
|
||||
log := *zerolog.Ctx(ctx)
|
||||
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
|
||||
sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey)
|
||||
|
|
@ -229,6 +260,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
|
|||
endTimeTrace()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Hex("ciphertext_hash", ciphertextHash[:]).
|
||||
Str("session_description", session.Describe()).
|
||||
Msg("Failed to decrypt olm message")
|
||||
if olmType == id.OlmMsgTypePreKey {
|
||||
|
|
@ -236,12 +268,19 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
|
|||
}
|
||||
} else {
|
||||
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
|
||||
err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now())
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting")
|
||||
}
|
||||
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
|
||||
}
|
||||
log.Debug().Str("session_description", session.Describe()).Msg("Decrypted olm message")
|
||||
log.Debug().
|
||||
Hex("ciphertext_hash", ciphertextHash[:]).
|
||||
Str("session_description", session.Describe()).
|
||||
Msg("Decrypted olm message")
|
||||
return plaintext, nil
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,6 +63,9 @@ type OlmMachine struct {
|
|||
devicesToUnwedgeLock sync.Mutex
|
||||
recentlyUnwedged map[id.IdentityKey]time.Time
|
||||
recentlyUnwedgedLock sync.Mutex
|
||||
olmHashSavePoints []time.Time
|
||||
lastHashDelete time.Time
|
||||
olmHashSavePointLock sync.Mutex
|
||||
|
||||
olmLock sync.Mutex
|
||||
megolmEncryptLock sync.Mutex
|
||||
|
|
@ -312,6 +315,7 @@ func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.R
|
|||
}
|
||||
|
||||
mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
|
||||
mach.MarkOlmHashSavePoint(ctx)
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
@ -399,6 +403,35 @@ func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Eve
|
|||
}
|
||||
}
|
||||
|
||||
const olmHashSavePointCount = 5
|
||||
const olmHashDeleteMinInterval = 10 * time.Minute
|
||||
const minSavePointInterval = 1 * time.Minute
|
||||
|
||||
// MarkOlmHashSavePoint marks the current time as a save point for olm hashes and deletes old hashes if needed.
|
||||
//
|
||||
// This should be called after all to-device events in a sync have been processed.
|
||||
// The function will then delete old olm hashes after enough syncs have happened
|
||||
// (such that it's unlikely for the olm messages to repeat).
|
||||
func (mach *OlmMachine) MarkOlmHashSavePoint(ctx context.Context) {
|
||||
mach.olmHashSavePointLock.Lock()
|
||||
defer mach.olmHashSavePointLock.Unlock()
|
||||
if len(mach.olmHashSavePoints) > 0 && time.Since(mach.olmHashSavePoints[len(mach.olmHashSavePoints)-1]) < minSavePointInterval {
|
||||
return
|
||||
}
|
||||
mach.olmHashSavePoints = append(mach.olmHashSavePoints, time.Now())
|
||||
if len(mach.olmHashSavePoints) > olmHashSavePointCount {
|
||||
sp := mach.olmHashSavePoints[0]
|
||||
mach.olmHashSavePoints = mach.olmHashSavePoints[1:]
|
||||
if time.Since(mach.lastHashDelete) > olmHashDeleteMinInterval {
|
||||
err := mach.CryptoStore.DeleteOldOlmHashes(ctx, sp)
|
||||
mach.lastHashDelete = time.Now()
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to delete old olm hashes")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you
|
||||
// don't need to add any custom handlers if you use that method.
|
||||
func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) {
|
||||
|
|
|
|||
|
|
@ -279,6 +279,29 @@ func (store *SQLCryptoStore) DeleteSession(ctx context.Context, _ id.SenderKey,
|
|||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) PutOlmHash(ctx context.Context, messageHash [32]byte, receivedAt time.Time) error {
|
||||
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_message_hash (account_id, received_at, message_hash) VALUES ($1, $2, $3) ON CONFLICT (message_hash) DO NOTHING", store.Account, messageHash[:], receivedAt.UnixMilli())
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetOlmHash(ctx context.Context, messageHash [32]byte) (receivedAt time.Time, err error) {
|
||||
var receivedAtInt int64
|
||||
err = store.DB.QueryRow(ctx, "SELECT received_at FROM crypto_olm_message_hash WHERE message_hash=$1", messageHash).Scan(&receivedAtInt)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
receivedAt = time.UnixMilli(receivedAtInt)
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) DeleteOldOlmHashes(ctx context.Context, beforeTS time.Time) error {
|
||||
_, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_message_hash WHERE account_id = $1 AND received_at < $2", store.AccountID, beforeTS.UnixMilli())
|
||||
return err
|
||||
}
|
||||
|
||||
func datePtr(t time.Time) *time.Time {
|
||||
if t.IsZero() {
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
-- v0 -> v16 (compatible with v15+): Latest revision
|
||||
-- v0 -> v17 (compatible with v15+): Latest revision
|
||||
CREATE TABLE IF NOT EXISTS crypto_account (
|
||||
account_id TEXT PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
|
|
@ -45,6 +45,16 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session (
|
|||
);
|
||||
CREATE INDEX crypto_olm_session_sender_key_idx ON crypto_olm_session (account_id, sender_key);
|
||||
|
||||
CREATE TABLE crypto_olm_message_hash (
|
||||
account_id TEXT NOT NULL,
|
||||
received_at BIGINT NOT NULL,
|
||||
message_hash bytea NOT NULL PRIMARY KEY,
|
||||
|
||||
CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id)
|
||||
REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
|
||||
account_id TEXT,
|
||||
session_id CHAR(43),
|
||||
|
|
|
|||
11
crypto/sql_store_upgrade/17-decrypted-olm-messages.sql
Normal file
11
crypto/sql_store_upgrade/17-decrypted-olm-messages.sql
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
-- v17 (compatible with v15+): Add table for decrypted Olm message hashes
|
||||
CREATE TABLE crypto_olm_message_hash (
|
||||
account_id TEXT NOT NULL,
|
||||
received_at BIGINT NOT NULL,
|
||||
message_hash bytea NOT NULL PRIMARY KEY,
|
||||
|
||||
CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id)
|
||||
REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id);
|
||||
|
|
@ -12,8 +12,10 @@ import (
|
|||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
"go.mau.fi/util/exsync"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
|
|
@ -51,6 +53,13 @@ type Store interface {
|
|||
// DeleteSession deletes the given session that has been previously inserted with AddSession.
|
||||
DeleteSession(context.Context, id.SenderKey, *OlmSession) error
|
||||
|
||||
// PutOlmHash marks a given olm message hash as handled.
|
||||
PutOlmHash(context.Context, [32]byte, time.Time) error
|
||||
// GetOlmHash gets the time that a given olm hash was handled.
|
||||
GetOlmHash(context.Context, [32]byte) (time.Time, error)
|
||||
// DeleteOldOlmHashes deletes all olm hashes that were handled before the given time.
|
||||
DeleteOldOlmHashes(context.Context, time.Time) error
|
||||
|
||||
// PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted
|
||||
// with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace
|
||||
// sessions inserted with this call.
|
||||
|
|
@ -176,6 +185,7 @@ type MemoryStore struct {
|
|||
KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string
|
||||
OutdatedUsers map[id.UserID]struct{}
|
||||
Secrets map[id.Secret]string
|
||||
OlmHashes *exsync.Set[[32]byte]
|
||||
}
|
||||
|
||||
var _ Store = (*MemoryStore)(nil)
|
||||
|
|
@ -198,6 +208,7 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore {
|
|||
KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string),
|
||||
OutdatedUsers: make(map[id.UserID]struct{}),
|
||||
Secrets: make(map[id.Secret]string),
|
||||
OlmHashes: exsync.NewSet[[32]byte](),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -263,6 +274,23 @@ func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) boo
|
|||
return ok && len(sessions) > 0 && !sessions[0].Expired()
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutOlmHash(_ context.Context, hash [32]byte, receivedAt time.Time) error {
|
||||
gs.OlmHashes.Add(hash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetOlmHash(_ context.Context, hash [32]byte) (time.Time, error) {
|
||||
if gs.OlmHashes.Has(hash) {
|
||||
// The time isn't that important, so we just return the current time
|
||||
return time.Now(), nil
|
||||
}
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) DeleteOldOlmHashes(_ context.Context, beforeTS time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue