Lock olm sessions between encrypting and sending

This commit is contained in:
Tulir Asokan 2020-10-02 01:07:53 +03:00
commit 575e242018
6 changed files with 144 additions and 44 deletions

View file

@ -89,6 +89,11 @@ func (mach *OlmMachine) newOutboundGroupSession(roomID id.RoomID) *OutboundGroup
return session
}
type deviceSessionWrapper struct {
session *OlmSession
identity *DeviceIdentity
}
// ShareGroupSession shares a group session for a specific room with all the devices of the given user list.
//
// For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent.
@ -105,8 +110,9 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
session = mach.newOutboundGroupSession(roomID)
}
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
withheldCount := 0
toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
olmSessions := make(map[id.UserID]map[id.DeviceID]deviceSessionWrapper)
missingSessions := make(map[id.UserID]map[id.DeviceID]*DeviceIdentity)
missingUserSessions := make(map[id.DeviceID]*DeviceIdentity)
var fetchKeys []id.UserID
@ -122,9 +128,10 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
mach.Log.Trace("%s has no devices, skipping", userID)
} else {
mach.Log.Trace("Trying to encrypt group session %s for %s", session.ID(), userID)
toDevice.Messages[userID] = make(map[id.DeviceID]*event.Content)
toDeviceWithheld.Messages[userID] = make(map[id.DeviceID]*event.Content)
mach.encryptGroupSessionForUser(session, userID, devices, toDevice.Messages[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
olmSessions[userID] = make(map[id.DeviceID]deviceSessionWrapper)
mach.findOlmSessionsForUser(session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
withheldCount += len(toDeviceWithheld.Messages[userID])
if len(missingUserSessions) > 0 {
missingSessions[userID] = missingUserSessions
missingUserSessions = make(map[id.DeviceID]*DeviceIdentity)
@ -132,9 +139,6 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
if len(toDeviceWithheld.Messages[userID]) == 0 {
delete(toDeviceWithheld.Messages, userID)
}
if len(toDevice.Messages[userID]) == 0 {
delete(toDevice.Messages, userID)
}
}
}
@ -146,10 +150,12 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
}
}
mach.Log.Trace("Creating missing outbound sessions")
err = mach.createOutboundSessions(missingSessions)
if err != nil {
mach.Log.Error("Failed to create missing outbound sessions: %v", err)
if len(missingSessions) > 0 {
mach.Log.Trace("Creating missing outbound sessions")
err = mach.createOutboundSessions(missingSessions)
if err != nil {
mach.Log.Error("Failed to create missing outbound sessions: %v", err)
}
}
for userID, devices := range missingSessions {
@ -157,10 +163,10 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
// No missing sessions
continue
}
output, ok := toDevice.Messages[userID]
output, ok := olmSessions[userID]
if !ok {
output = make(map[id.DeviceID]*event.Content)
toDevice.Messages[userID] = output
output = make(map[id.DeviceID]deviceSessionWrapper)
olmSessions[userID] = output
}
withheld, ok := toDeviceWithheld.Messages[userID]
if !ok {
@ -168,27 +174,29 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
toDeviceWithheld.Messages[userID] = withheld
}
mach.Log.Trace("Trying to encrypt group session %s for %s (post-fetch retry)", session.ID(), userID)
mach.encryptGroupSessionForUser(session, userID, devices, output, withheld, nil)
mach.findOlmSessionsForUser(session, userID, devices, output, withheld, nil)
withheldCount += len(toDeviceWithheld.Messages[userID])
if len(toDeviceWithheld.Messages[userID]) == 0 {
delete(toDeviceWithheld.Messages, userID)
}
if len(toDevice.Messages[userID]) == 0 {
delete(toDevice.Messages, userID)
}
}
mach.Log.Trace("Sending to-device to %d users to share group session for %s", len(toDevice.Messages), roomID)
_, err = mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice)
err = mach.encryptAndSendGroupSession(session, olmSessions)
if err != nil {
return fmt.Errorf("failed to share group session: %w", err)
}
mach.Log.Trace("Sending to-device messages to %d users to report withheld keys in %s", len(toDeviceWithheld.Messages), roomID)
// TODO remove the next line once clients support m.room_key.withheld
_, _ = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
if err != nil {
mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err)
if len(toDeviceWithheld.Messages) > 0 {
mach.Log.Trace("Sending to-device messages to %d devices of %d users to report withheld keys in %s", withheldCount, len(toDeviceWithheld.Messages), roomID)
// TODO remove the next 4 lines once clients support m.room_key.withheld
_, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
if err != nil {
mach.Log.Warn("Failed to report withheld keys in %s (legacy event type): %v", roomID, err)
}
_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
if err != nil {
mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err)
}
}
mach.Log.Debug("Group session %s for %s successfully shared", session.ID(), roomID)
@ -196,7 +204,32 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
return mach.CryptoStore.AddOutboundGroupSession(session)
}
func (mach *OlmMachine) encryptGroupSessionForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*DeviceIdentity, output, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*DeviceIdentity) {
func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
deviceCount := 0
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
for userID, sessions := range olmSessions {
if len(sessions) == 0 {
continue
}
output := make(map[id.DeviceID]*event.Content)
toDevice.Messages[userID] = output
for deviceID, device := range sessions {
device.session.Lock()
// We intentionally defer in a loop as it's the safest way of making sure nothing gets locked permanently.
defer device.session.Unlock()
content := mach.encryptOlmEvent(device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
output[deviceID] = &event.Content{Parsed: content}
deviceCount++
mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID)
}
}
mach.Log.Trace("Sending to-device to %d devices of %d users to share group session %s", deviceCount, len(toDevice.Messages), session.ID())
_, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice)
return err
}
func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*DeviceIdentity, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*DeviceIdentity) {
for deviceID, device := range devices {
userKey := UserDevice{UserID: userID, DeviceID: deviceID}
if state := session.Users[userKey]; state != OGSNotShared {
@ -233,10 +266,11 @@ func (mach *OlmMachine) encryptGroupSessionForUser(session *OutboundGroupSession
missingOutput[deviceID] = device
}
} else {
content := mach.encryptOlmEvent(deviceSession, device, event.ToDeviceRoomKey, session.ShareContent())
output[deviceID] = &event.Content{Parsed: content}
output[deviceID] = deviceSessionWrapper{
session: deviceSession,
identity: device,
}
session.Users[userKey] = OGSAlreadyShared
mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID)
}
}
}

View file

@ -288,6 +288,9 @@ func (mach *OlmMachine) SendEncryptedToDevice(device *DeviceIdentity, content ev
return fmt.Errorf("didn't find created outbound session for device %s of %s", device.DeviceID, device.UserID)
}
olmSess.Lock()
defer olmSess.Unlock()
encrypted := mach.encryptOlmEvent(olmSess, device, event.ToDeviceForwardedRoomKey, content)
encryptedContent := &event.Content{Parsed: &encrypted}
@ -319,7 +322,7 @@ func (mach *OlmMachine) createGroupSession(senderKey id.SenderKey, signingKey id
return
}
mach.markSessionReceived(sessionID)
mach.Log.Trace("Created inbound group session %s/%s/%s", roomID, senderKey, sessionID)
mach.Log.Debug("Received inbound group session %s / %s / %s", roomID, senderKey, sessionID)
}
func (mach *OlmMachine) markSessionReceived(id id.SessionID) {

View file

@ -9,6 +9,7 @@ package crypto
import (
"errors"
"strings"
"sync"
"time"
"maunium.net/go/mautrix/crypto/olm"
@ -42,6 +43,20 @@ type OlmSession struct {
Internal olm.Session
ExpirationMixin
id id.SessionID
// This is unexported so gob wouldn't insist on trying to marshaling it
lock sync.Locker
}
func (session *OlmSession) SetLock(lock sync.Locker) {
session.lock = lock
}
func (session *OlmSession) Lock() {
session.lock.Lock()
}
func (session *OlmSession) Unlock() {
session.lock.Unlock()
}
func (session *OlmSession) ID() id.SessionID {
@ -54,6 +69,7 @@ func (session *OlmSession) ID() id.SessionID {
func wrapSession(session *olm.Session) *OlmSession {
return &OlmSession{
Internal: *session,
lock: &sync.Mutex{},
ExpirationMixin: ExpirationMixin{
TimeMixin: TimeMixin{
CreationTime: time.Now(),

View file

@ -10,6 +10,7 @@ import (
"database/sql"
"fmt"
"strings"
"sync"
"github.com/lib/pq"
@ -30,6 +31,9 @@ type SQLCryptoStore struct {
SyncToken string
PickleKey []byte
Account *OlmAccount
olmSessionCache map[id.SenderKey]map[id.SessionID]*OlmSession
olmSessionCacheLock sync.Mutex
}
var _ Store = (*SQLCryptoStore)(nil)
@ -44,6 +48,8 @@ func NewSQLCryptoStore(db *sql.DB, dialect string, accountID string, deviceID id
PickleKey: pickleKey,
AccountID: accountID,
DeviceID: deviceID,
olmSessionCache: make(map[id.SenderKey]map[id.SessionID]*OlmSession),
}
}
@ -124,7 +130,12 @@ func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) {
// HasSession returns whether there is an Olm session for the given sender key.
func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
// TODO this may need to be changed if olm sessions start expiring
store.olmSessionCacheLock.Lock()
cache, ok := store.olmSessionCache[key]
store.olmSessionCacheLock.Unlock()
if ok && len(cache) > 0 {
return true
}
var sessionID id.SessionID
err := store.DB.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1",
key, store.AccountID).Scan(&sessionID)
@ -136,48 +147,83 @@ func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
// GetSessions returns all the known Olm sessions for a sender key.
func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (OlmSessionList, error) {
rows, err := store.DB.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id",
rows, err := store.DB.Query("SELECT session_id, session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id",
key, store.AccountID)
if err != nil {
return nil, err
}
list := OlmSessionList{}
store.olmSessionCacheLock.Lock()
defer store.olmSessionCacheLock.Unlock()
cache := store.getOlmSessionCache(key)
for rows.Next() {
sess := OlmSession{Internal: *olm.NewBlankSession()}
sess := OlmSession{Internal: *olm.NewBlankSession(), lock: &sync.Mutex{}}
var sessionBytes []byte
err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
var sessionID id.SessionID
err := rows.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.UseTime)
if err != nil {
return nil, err
} else if existing, ok := cache[sessionID]; ok {
list = append(list, existing)
} else {
err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return nil, err
}
list = append(list, &sess)
cache[sess.ID()] = &sess
}
err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return nil, err
}
list = append(list, &sess)
}
return list, nil
}
func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.SessionID]*OlmSession {
data, ok := store.olmSessionCache[key]
if !ok {
data = make(map[id.SessionID]*OlmSession)
store.olmSessionCache[key] = data
}
return data
}
// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID.
func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, error) {
row := store.DB.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id DESC LIMIT 1",
store.olmSessionCacheLock.Lock()
defer store.olmSessionCacheLock.Unlock()
row := store.DB.QueryRow("SELECT session_id, session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY session_id DESC LIMIT 1",
key, store.AccountID)
sess := OlmSession{Internal: *olm.NewBlankSession()}
sess := OlmSession{Internal: *olm.NewBlankSession(), lock: &sync.Mutex{}}
var sessionBytes []byte
err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
var sessionID id.SessionID
err := row.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.UseTime)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey)
cache := store.getOlmSessionCache(key)
if oldSess, ok := cache[sessionID]; ok {
return oldSess, nil
} else if err = sess.Internal.Unpickle(sessionBytes, store.PickleKey); err != nil {
return nil, err
} else {
cache[sessionID] = &sess
return &sess, nil
}
}
// AddSession persists an Olm session for a sender in the database.
func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) error {
store.olmSessionCacheLock.Lock()
defer store.olmSessionCacheLock.Unlock()
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used, account_id) VALUES ($1, $2, $3, $4, $5, $6)",
session.ID(), key, sessionBytes, session.CreationTime, session.UseTime, store.AccountID)
store.getOlmSessionCache(key)[session.ID()] = session
return err
}

View file

@ -66,6 +66,7 @@ var ErrGroupSessionWithheld = errors.New("group session has been withheld")
// General implementation details:
// * Get methods should not return errors if the requested data does not exist in the store, they should simply return nil.
// * Update methods may assume that the pointer is the same as what has earlier been added to or fetched from the store.
// * OlmSessions should be cached so that the mutex works. Alternatively, implementations can use OlmSession.SetLock to provide a custom mutex implementation.
type Store interface {
// Flush ensures that everything in the store is persisted to disk.
// This doesn't have to do anything, e.g. for database-backed implementations that persist everything immediately.

View file

@ -1,3 +1,3 @@
package mautrix
const Version = "v0.7.8"
const Version = "v0.7.9"