mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Lock olm sessions between encrypting and sending
This commit is contained in:
parent
819fedddbb
commit
575e242018
6 changed files with 144 additions and 44 deletions
|
|
@ -89,6 +89,11 @@ func (mach *OlmMachine) newOutboundGroupSession(roomID id.RoomID) *OutboundGroup
|
|||
return session
|
||||
}
|
||||
|
||||
type deviceSessionWrapper struct {
|
||||
session *OlmSession
|
||||
identity *DeviceIdentity
|
||||
}
|
||||
|
||||
// ShareGroupSession shares a group session for a specific room with all the devices of the given user list.
|
||||
//
|
||||
// For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent.
|
||||
|
|
@ -105,8 +110,9 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
|
|||
session = mach.newOutboundGroupSession(roomID)
|
||||
}
|
||||
|
||||
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
|
||||
withheldCount := 0
|
||||
toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
|
||||
olmSessions := make(map[id.UserID]map[id.DeviceID]deviceSessionWrapper)
|
||||
missingSessions := make(map[id.UserID]map[id.DeviceID]*DeviceIdentity)
|
||||
missingUserSessions := make(map[id.DeviceID]*DeviceIdentity)
|
||||
var fetchKeys []id.UserID
|
||||
|
|
@ -122,9 +128,10 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
|
|||
mach.Log.Trace("%s has no devices, skipping", userID)
|
||||
} else {
|
||||
mach.Log.Trace("Trying to encrypt group session %s for %s", session.ID(), userID)
|
||||
toDevice.Messages[userID] = make(map[id.DeviceID]*event.Content)
|
||||
toDeviceWithheld.Messages[userID] = make(map[id.DeviceID]*event.Content)
|
||||
mach.encryptGroupSessionForUser(session, userID, devices, toDevice.Messages[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
|
||||
olmSessions[userID] = make(map[id.DeviceID]deviceSessionWrapper)
|
||||
mach.findOlmSessionsForUser(session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
|
||||
withheldCount += len(toDeviceWithheld.Messages[userID])
|
||||
if len(missingUserSessions) > 0 {
|
||||
missingSessions[userID] = missingUserSessions
|
||||
missingUserSessions = make(map[id.DeviceID]*DeviceIdentity)
|
||||
|
|
@ -132,9 +139,6 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
|
|||
if len(toDeviceWithheld.Messages[userID]) == 0 {
|
||||
delete(toDeviceWithheld.Messages, userID)
|
||||
}
|
||||
if len(toDevice.Messages[userID]) == 0 {
|
||||
delete(toDevice.Messages, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -146,10 +150,12 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
|
|||
}
|
||||
}
|
||||
|
||||
mach.Log.Trace("Creating missing outbound sessions")
|
||||
err = mach.createOutboundSessions(missingSessions)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to create missing outbound sessions: %v", err)
|
||||
if len(missingSessions) > 0 {
|
||||
mach.Log.Trace("Creating missing outbound sessions")
|
||||
err = mach.createOutboundSessions(missingSessions)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to create missing outbound sessions: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for userID, devices := range missingSessions {
|
||||
|
|
@ -157,10 +163,10 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
|
|||
// No missing sessions
|
||||
continue
|
||||
}
|
||||
output, ok := toDevice.Messages[userID]
|
||||
output, ok := olmSessions[userID]
|
||||
if !ok {
|
||||
output = make(map[id.DeviceID]*event.Content)
|
||||
toDevice.Messages[userID] = output
|
||||
output = make(map[id.DeviceID]deviceSessionWrapper)
|
||||
olmSessions[userID] = output
|
||||
}
|
||||
withheld, ok := toDeviceWithheld.Messages[userID]
|
||||
if !ok {
|
||||
|
|
@ -168,27 +174,29 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
|
|||
toDeviceWithheld.Messages[userID] = withheld
|
||||
}
|
||||
mach.Log.Trace("Trying to encrypt group session %s for %s (post-fetch retry)", session.ID(), userID)
|
||||
mach.encryptGroupSessionForUser(session, userID, devices, output, withheld, nil)
|
||||
mach.findOlmSessionsForUser(session, userID, devices, output, withheld, nil)
|
||||
withheldCount += len(toDeviceWithheld.Messages[userID])
|
||||
if len(toDeviceWithheld.Messages[userID]) == 0 {
|
||||
delete(toDeviceWithheld.Messages, userID)
|
||||
}
|
||||
if len(toDevice.Messages[userID]) == 0 {
|
||||
delete(toDevice.Messages, userID)
|
||||
}
|
||||
}
|
||||
|
||||
mach.Log.Trace("Sending to-device to %d users to share group session for %s", len(toDevice.Messages), roomID)
|
||||
_, err = mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice)
|
||||
err = mach.encryptAndSendGroupSession(session, olmSessions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to share group session: %w", err)
|
||||
}
|
||||
|
||||
mach.Log.Trace("Sending to-device messages to %d users to report withheld keys in %s", len(toDeviceWithheld.Messages), roomID)
|
||||
// TODO remove the next line once clients support m.room_key.withheld
|
||||
_, _ = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
|
||||
_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err)
|
||||
if len(toDeviceWithheld.Messages) > 0 {
|
||||
mach.Log.Trace("Sending to-device messages to %d devices of %d users to report withheld keys in %s", withheldCount, len(toDeviceWithheld.Messages), roomID)
|
||||
// TODO remove the next 4 lines once clients support m.room_key.withheld
|
||||
_, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to report withheld keys in %s (legacy event type): %v", roomID, err)
|
||||
}
|
||||
_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err)
|
||||
}
|
||||
}
|
||||
|
||||
mach.Log.Debug("Group session %s for %s successfully shared", session.ID(), roomID)
|
||||
|
|
@ -196,7 +204,32 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
|
|||
return mach.CryptoStore.AddOutboundGroupSession(session)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) encryptGroupSessionForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*DeviceIdentity, output, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*DeviceIdentity) {
|
||||
func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
|
||||
deviceCount := 0
|
||||
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
|
||||
for userID, sessions := range olmSessions {
|
||||
if len(sessions) == 0 {
|
||||
continue
|
||||
}
|
||||
output := make(map[id.DeviceID]*event.Content)
|
||||
toDevice.Messages[userID] = output
|
||||
for deviceID, device := range sessions {
|
||||
device.session.Lock()
|
||||
// We intentionally defer in a loop as it's the safest way of making sure nothing gets locked permanently.
|
||||
defer device.session.Unlock()
|
||||
content := mach.encryptOlmEvent(device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
|
||||
output[deviceID] = &event.Content{Parsed: content}
|
||||
deviceCount++
|
||||
mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID)
|
||||
}
|
||||
}
|
||||
|
||||
mach.Log.Trace("Sending to-device to %d devices of %d users to share group session %s", deviceCount, len(toDevice.Messages), session.ID())
|
||||
_, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice)
|
||||
return err
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*DeviceIdentity, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*DeviceIdentity) {
|
||||
for deviceID, device := range devices {
|
||||
userKey := UserDevice{UserID: userID, DeviceID: deviceID}
|
||||
if state := session.Users[userKey]; state != OGSNotShared {
|
||||
|
|
@ -233,10 +266,11 @@ func (mach *OlmMachine) encryptGroupSessionForUser(session *OutboundGroupSession
|
|||
missingOutput[deviceID] = device
|
||||
}
|
||||
} else {
|
||||
content := mach.encryptOlmEvent(deviceSession, device, event.ToDeviceRoomKey, session.ShareContent())
|
||||
output[deviceID] = &event.Content{Parsed: content}
|
||||
output[deviceID] = deviceSessionWrapper{
|
||||
session: deviceSession,
|
||||
identity: device,
|
||||
}
|
||||
session.Users[userKey] = OGSAlreadyShared
|
||||
mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -288,6 +288,9 @@ func (mach *OlmMachine) SendEncryptedToDevice(device *DeviceIdentity, content ev
|
|||
return fmt.Errorf("didn't find created outbound session for device %s of %s", device.DeviceID, device.UserID)
|
||||
}
|
||||
|
||||
olmSess.Lock()
|
||||
defer olmSess.Unlock()
|
||||
|
||||
encrypted := mach.encryptOlmEvent(olmSess, device, event.ToDeviceForwardedRoomKey, content)
|
||||
encryptedContent := &event.Content{Parsed: &encrypted}
|
||||
|
||||
|
|
@ -319,7 +322,7 @@ func (mach *OlmMachine) createGroupSession(senderKey id.SenderKey, signingKey id
|
|||
return
|
||||
}
|
||||
mach.markSessionReceived(sessionID)
|
||||
mach.Log.Trace("Created inbound group session %s/%s/%s", roomID, senderKey, sessionID)
|
||||
mach.Log.Debug("Received inbound group session %s / %s / %s", roomID, senderKey, sessionID)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) markSessionReceived(id id.SessionID) {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ package crypto
|
|||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
|
|
@ -42,6 +43,20 @@ type OlmSession struct {
|
|||
Internal olm.Session
|
||||
ExpirationMixin
|
||||
id id.SessionID
|
||||
// This is unexported so gob wouldn't insist on trying to marshaling it
|
||||
lock sync.Locker
|
||||
}
|
||||
|
||||
func (session *OlmSession) SetLock(lock sync.Locker) {
|
||||
session.lock = lock
|
||||
}
|
||||
|
||||
func (session *OlmSession) Lock() {
|
||||
session.lock.Lock()
|
||||
}
|
||||
|
||||
func (session *OlmSession) Unlock() {
|
||||
session.lock.Unlock()
|
||||
}
|
||||
|
||||
func (session *OlmSession) ID() id.SessionID {
|
||||
|
|
@ -54,6 +69,7 @@ func (session *OlmSession) ID() id.SessionID {
|
|||
func wrapSession(session *olm.Session) *OlmSession {
|
||||
return &OlmSession{
|
||||
Internal: *session,
|
||||
lock: &sync.Mutex{},
|
||||
ExpirationMixin: ExpirationMixin{
|
||||
TimeMixin: TimeMixin{
|
||||
CreationTime: time.Now(),
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
|
|
@ -30,6 +31,9 @@ type SQLCryptoStore struct {
|
|||
SyncToken string
|
||||
PickleKey []byte
|
||||
Account *OlmAccount
|
||||
|
||||
olmSessionCache map[id.SenderKey]map[id.SessionID]*OlmSession
|
||||
olmSessionCacheLock sync.Mutex
|
||||
}
|
||||
|
||||
var _ Store = (*SQLCryptoStore)(nil)
|
||||
|
|
@ -44,6 +48,8 @@ func NewSQLCryptoStore(db *sql.DB, dialect string, accountID string, deviceID id
|
|||
PickleKey: pickleKey,
|
||||
AccountID: accountID,
|
||||
DeviceID: deviceID,
|
||||
|
||||
olmSessionCache: make(map[id.SenderKey]map[id.SessionID]*OlmSession),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -124,7 +130,12 @@ func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) {
|
|||
|
||||
// HasSession returns whether there is an Olm session for the given sender key.
|
||||
func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
|
||||
// TODO this may need to be changed if olm sessions start expiring
|
||||
store.olmSessionCacheLock.Lock()
|
||||
cache, ok := store.olmSessionCache[key]
|
||||
store.olmSessionCacheLock.Unlock()
|
||||
if ok && len(cache) > 0 {
|
||||
return true
|
||||
}
|
||||
var sessionID id.SessionID
|
||||
err := store.DB.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1",
|
||||
key, store.AccountID).Scan(&sessionID)
|
||||
|
|
@ -136,48 +147,83 @@ func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
|
|||
|
||||
// GetSessions returns all the known Olm sessions for a sender key.
|
||||
func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (OlmSessionList, error) {
|
||||
rows, err := store.DB.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id",
|
||||
rows, err := store.DB.Query("SELECT session_id, session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id",
|
||||
key, store.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list := OlmSessionList{}
|
||||
store.olmSessionCacheLock.Lock()
|
||||
defer store.olmSessionCacheLock.Unlock()
|
||||
cache := store.getOlmSessionCache(key)
|
||||
for rows.Next() {
|
||||
sess := OlmSession{Internal: *olm.NewBlankSession()}
|
||||
sess := OlmSession{Internal: *olm.NewBlankSession(), lock: &sync.Mutex{}}
|
||||
var sessionBytes []byte
|
||||
err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
|
||||
var sessionID id.SessionID
|
||||
err := rows.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.UseTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if existing, ok := cache[sessionID]; ok {
|
||||
list = append(list, existing)
|
||||
} else {
|
||||
err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, &sess)
|
||||
cache[sess.ID()] = &sess
|
||||
}
|
||||
err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, &sess)
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.SessionID]*OlmSession {
|
||||
data, ok := store.olmSessionCache[key]
|
||||
if !ok {
|
||||
data = make(map[id.SessionID]*OlmSession)
|
||||
store.olmSessionCache[key] = data
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID.
|
||||
func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, error) {
|
||||
row := store.DB.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id DESC LIMIT 1",
|
||||
store.olmSessionCacheLock.Lock()
|
||||
defer store.olmSessionCacheLock.Unlock()
|
||||
|
||||
row := store.DB.QueryRow("SELECT session_id, session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id DESC LIMIT 1",
|
||||
key, store.AccountID)
|
||||
sess := OlmSession{Internal: *olm.NewBlankSession()}
|
||||
|
||||
sess := OlmSession{Internal: *olm.NewBlankSession(), lock: &sync.Mutex{}}
|
||||
var sessionBytes []byte
|
||||
err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
|
||||
var sessionID id.SessionID
|
||||
|
||||
err := row.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.UseTime)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey)
|
||||
|
||||
cache := store.getOlmSessionCache(key)
|
||||
if oldSess, ok := cache[sessionID]; ok {
|
||||
return oldSess, nil
|
||||
} else if err = sess.Internal.Unpickle(sessionBytes, store.PickleKey); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
cache[sessionID] = &sess
|
||||
return &sess, nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddSession persists an Olm session for a sender in the database.
|
||||
func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) error {
|
||||
store.olmSessionCacheLock.Lock()
|
||||
defer store.olmSessionCacheLock.Unlock()
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used, account_id) VALUES ($1, $2, $3, $4, $5, $6)",
|
||||
session.ID(), key, sessionBytes, session.CreationTime, session.UseTime, store.AccountID)
|
||||
store.getOlmSessionCache(key)[session.ID()] = session
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ var ErrGroupSessionWithheld = errors.New("group session has been withheld")
|
|||
// General implementation details:
|
||||
// * Get methods should not return errors if the requested data does not exist in the store, they should simply return nil.
|
||||
// * Update methods may assume that the pointer is the same as what has earlier been added to or fetched from the store.
|
||||
// * OlmSessions should be cached so that the mutex works. Alternatively, implementations can use OlmSession.SetLock to provide a custom mutex implementation.
|
||||
type Store interface {
|
||||
// Flush ensures that everything in the store is persisted to disk.
|
||||
// This doesn't have to do anything, e.g. for database-backed implementations that persist everything immediately.
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
package mautrix
|
||||
|
||||
const Version = "v0.7.8"
|
||||
const Version = "v0.7.9"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue