Stop using non-existent device_id field when handling olm messages

This commit is contained in:
Tulir Asokan 2021-12-03 17:47:41 +02:00
commit 087644889b
5 changed files with 80 additions and 20 deletions

View file

@ -54,7 +54,7 @@ func (mach *OlmMachine) decryptOlmEvent(evt *event.Event) (*DecryptedOlmEvent, e
if !ok {
return nil, NotEncryptedForMe
}
decrypted, err := mach.decryptAndParseOlmCiphertext(evt.Sender, content.DeviceID, content.SenderKey, ownContent.Type, ownContent.Body)
decrypted, err := mach.decryptAndParseOlmCiphertext(evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body)
if err != nil {
return nil, err
}
@ -66,12 +66,12 @@ type OlmEventKeys struct {
Ed25519 id.Ed25519 `json:"ed25519"`
}
func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, deviceID id.DeviceID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
return nil, UnsupportedOlmMessageType
}
plaintext, err := mach.tryDecryptOlmCiphertext(sender, deviceID, senderKey, olmType, ciphertext)
plaintext, err := mach.tryDecryptOlmCiphertext(sender, senderKey, olmType, ciphertext)
if err != nil {
return nil, err
}
@ -99,7 +99,7 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, deviceID
return &olmEvt, nil
}
func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, deviceID id.DeviceID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
mach.olmLock.Lock()
defer mach.olmLock.Unlock()
@ -107,7 +107,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, deviceID id.De
if err != nil {
if err == DecryptionFailedWithMatchingSession {
mach.Log.Warn("Found matching session yet decryption failed for sender %s with key %s", sender, senderKey)
go mach.unwedgeDevice(sender, deviceID, senderKey)
go mach.unwedgeDevice(sender, senderKey)
}
return nil, fmt.Errorf("failed to decrypt olm event: %w", err)
}
@ -122,17 +122,17 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, deviceID id.De
// New sessions can only be created if it's a prekey message, we can't decrypt the message
// if it isn't one at this point in time anymore, so return early.
if olmType != id.OlmMsgTypePreKey {
go mach.unwedgeDevice(sender, deviceID, senderKey)
go mach.unwedgeDevice(sender, senderKey)
return nil, DecryptionFailedForNormalMessage
}
mach.Log.Trace("Trying to create inbound session for %s/%s", sender, deviceID)
mach.Log.Trace("Trying to create inbound session for %s/%s", sender, senderKey)
session, err := mach.createInboundSession(senderKey, ciphertext)
if err != nil {
go mach.unwedgeDevice(sender, deviceID, senderKey)
go mach.unwedgeDevice(sender, senderKey)
return nil, fmt.Errorf("failed to create new session from prekey message: %w", err)
}
mach.Log.Debug("Created inbound olm session %s for %s/%s (sender key: %s)", session.ID(), sender, deviceID, senderKey)
mach.Log.Debug("Created inbound olm session %s for %s/%s", session.ID(), sender, senderKey)
plaintext, err = session.Decrypt(ciphertext, olmType)
if err != nil {
@ -191,25 +191,31 @@ func (mach *OlmMachine) createInboundSession(senderKey id.SenderKey, ciphertext
const MinUnwedgeInterval = 1 * time.Hour
func (mach *OlmMachine) unwedgeDevice(sender id.UserID, deviceID id.DeviceID, senderKey id.SenderKey) {
func (mach *OlmMachine) unwedgeDevice(sender id.UserID, senderKey id.SenderKey) {
mach.recentlyUnwedgedLock.Lock()
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
delta := time.Now().Sub(prevUnwedge)
if ok && delta < MinUnwedgeInterval {
mach.Log.Debug("Not creating new Olm session with %s/%s, previous recreation was %s ago", sender, deviceID, delta)
mach.Log.Debug("Not creating new Olm session with %s/%s, previous recreation was %s ago", sender, senderKey, delta)
mach.recentlyUnwedgedLock.Unlock()
return
}
mach.recentlyUnwedged[senderKey] = time.Now()
mach.recentlyUnwedgedLock.Unlock()
mach.Log.Debug("Creating new Olm session with %s/%s...", sender, deviceID)
mach.devicesToUnwedge.Store(senderKey, true)
err := mach.SendEncryptedToDevice(&DeviceIdentity{
UserID: sender,
DeviceID: deviceID,
IdentityKey: senderKey,
}, event.ToDeviceDummy, event.Content{})
deviceIdentity, err := mach.GetOrFetchDeviceByKey(sender, senderKey)
if err != nil {
mach.Log.Error("Failed to send dummy event to unwedge session with %s/%s: %v", sender, deviceID, err)
mach.Log.Error("Failed to find device info by identity key: %v", err)
return
} else if deviceIdentity == nil {
mach.Log.Warn("Didn't find identity of %s/%s, can't unwedge session", sender, senderKey)
return
}
mach.Log.Debug("Creating new Olm session with %s/%s (key: %s)", sender, deviceIdentity.DeviceID, senderKey)
mach.devicesToUnwedge.Store(senderKey, true)
err = mach.SendEncryptedToDevice(deviceIdentity, event.ToDeviceDummy, event.Content{})
if err != nil {
mach.Log.Error("Failed to send dummy event to unwedge session with %s/%s: %v", sender, senderKey, err)
}
}

View file

@ -333,6 +333,24 @@ func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID)
return nil, fmt.Errorf("didn't get any devices for %s", userID)
}
// GetOrFetchDeviceByKey attempts to retrieve the device identity for the device with the given identity key from the
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
// the given identity key.
func (mach *OlmMachine) GetOrFetchDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*DeviceIdentity, error) {
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
if err != nil || deviceIdentity != nil {
return deviceIdentity, err
}
mach.Log.Debug("Didn't find identity of %s/%s in crypto store, fetching from server", userID, identityKey)
devices := mach.LoadDevices(userID)
for _, device := range devices {
if device.IdentityKey == identityKey {
return device, nil
}
}
return nil, nil
}
// SendEncryptedToDevice sends an Olm-encrypted event to the given user device.
func (mach *OlmMachine) SendEncryptedToDevice(device *DeviceIdentity, evtType event.Type, content event.Content) error {
if err := mach.createOutboundSessions(map[id.UserID]map[id.DeviceID]*DeviceIdentity{

View file

@ -506,6 +506,25 @@ func (store *SQLCryptoStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (
return &identity, nil
}
// FindDeviceByKey finds a specific device by its sender key.
func (store *SQLCryptoStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*DeviceIdentity, error) {
var identity DeviceIdentity
err := store.DB.QueryRow(`
SELECT device_id, identity_key, signing_key, trust, deleted, name
FROM crypto_device WHERE user_id=$1 AND identity_key=$2`,
userID, identityKey,
).Scan(&identity.DeviceID, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
identity.UserID = userID
identity.IdentityKey = identityKey
return &identity, nil
}
// PutDevice stores a single device for a user, replacing it if it exists already.
func (store *SQLCryptoStore) PutDevice(userID id.UserID, device *DeviceIdentity) error {
var err error

View file

@ -138,6 +138,8 @@ type Store interface {
PutDevice(id.UserID, *DeviceIdentity) error
// PutDevices overrides the stored device list for the given user with the given list.
PutDevices(id.UserID, map[id.DeviceID]*DeviceIdentity) error
// FindDeviceByKey finds a specific device by its identity key.
FindDeviceByKey(id.UserID, id.IdentityKey) (*DeviceIdentity, error)
// FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists
// have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty.
FilterTrackedUsers([]id.UserID) []id.UserID
@ -471,6 +473,21 @@ func (gs *GobStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*DeviceId
return device, nil
}
func (gs *GobStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*DeviceIdentity, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
devices, ok := gs.Devices[userID]
if !ok {
return nil, nil
}
for _, device := range devices {
if device.IdentityKey == identityKey {
return device, nil
}
}
return nil, nil
}
func (gs *GobStore) PutDevice(userID id.UserID, device *DeviceIdentity) error {
gs.lock.Lock()
devices, ok := gs.Devices[userID]

View file

@ -28,8 +28,8 @@ type EncryptionEventContent struct {
type EncryptedEventContent struct {
Algorithm id.Algorithm `json:"algorithm"`
SenderKey id.SenderKey `json:"sender_key"`
DeviceID id.DeviceID `json:"device_id,omitempty"`
SessionID id.SessionID `json:"session_id,omitempty"`
DeviceID id.DeviceID `json:"device_id,omitempty"` // Only present for Megolm events
SessionID id.SessionID `json:"session_id,omitempty"` // Only present for Megolm events
Ciphertext json.RawMessage `json:"ciphertext"`
MegolmCiphertext []byte `json:"-"`