mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Stop using non-existent device_id field when handling olm messages
This commit is contained in:
parent
5b4a845029
commit
087644889b
5 changed files with 80 additions and 20 deletions
|
|
@ -54,7 +54,7 @@ func (mach *OlmMachine) decryptOlmEvent(evt *event.Event) (*DecryptedOlmEvent, e
|
|||
if !ok {
|
||||
return nil, NotEncryptedForMe
|
||||
}
|
||||
decrypted, err := mach.decryptAndParseOlmCiphertext(evt.Sender, content.DeviceID, content.SenderKey, ownContent.Type, ownContent.Body)
|
||||
decrypted, err := mach.decryptAndParseOlmCiphertext(evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -66,12 +66,12 @@ type OlmEventKeys struct {
|
|||
Ed25519 id.Ed25519 `json:"ed25519"`
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, deviceID id.DeviceID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
|
||||
func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
|
||||
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
|
||||
return nil, UnsupportedOlmMessageType
|
||||
}
|
||||
|
||||
plaintext, err := mach.tryDecryptOlmCiphertext(sender, deviceID, senderKey, olmType, ciphertext)
|
||||
plaintext, err := mach.tryDecryptOlmCiphertext(sender, senderKey, olmType, ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -99,7 +99,7 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, deviceID
|
|||
return &olmEvt, nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, deviceID id.DeviceID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
|
||||
mach.olmLock.Lock()
|
||||
defer mach.olmLock.Unlock()
|
||||
|
||||
|
|
@ -107,7 +107,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, deviceID id.De
|
|||
if err != nil {
|
||||
if err == DecryptionFailedWithMatchingSession {
|
||||
mach.Log.Warn("Found matching session yet decryption failed for sender %s with key %s", sender, senderKey)
|
||||
go mach.unwedgeDevice(sender, deviceID, senderKey)
|
||||
go mach.unwedgeDevice(sender, senderKey)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to decrypt olm event: %w", err)
|
||||
}
|
||||
|
|
@ -122,17 +122,17 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, deviceID id.De
|
|||
// New sessions can only be created if it's a prekey message, we can't decrypt the message
|
||||
// if it isn't one at this point in time anymore, so return early.
|
||||
if olmType != id.OlmMsgTypePreKey {
|
||||
go mach.unwedgeDevice(sender, deviceID, senderKey)
|
||||
go mach.unwedgeDevice(sender, senderKey)
|
||||
return nil, DecryptionFailedForNormalMessage
|
||||
}
|
||||
|
||||
mach.Log.Trace("Trying to create inbound session for %s/%s", sender, deviceID)
|
||||
mach.Log.Trace("Trying to create inbound session for %s/%s", sender, senderKey)
|
||||
session, err := mach.createInboundSession(senderKey, ciphertext)
|
||||
if err != nil {
|
||||
go mach.unwedgeDevice(sender, deviceID, senderKey)
|
||||
go mach.unwedgeDevice(sender, senderKey)
|
||||
return nil, fmt.Errorf("failed to create new session from prekey message: %w", err)
|
||||
}
|
||||
mach.Log.Debug("Created inbound olm session %s for %s/%s (sender key: %s)", session.ID(), sender, deviceID, senderKey)
|
||||
mach.Log.Debug("Created inbound olm session %s for %s/%s", session.ID(), sender, senderKey)
|
||||
|
||||
plaintext, err = session.Decrypt(ciphertext, olmType)
|
||||
if err != nil {
|
||||
|
|
@ -191,25 +191,31 @@ func (mach *OlmMachine) createInboundSession(senderKey id.SenderKey, ciphertext
|
|||
|
||||
const MinUnwedgeInterval = 1 * time.Hour
|
||||
|
||||
func (mach *OlmMachine) unwedgeDevice(sender id.UserID, deviceID id.DeviceID, senderKey id.SenderKey) {
|
||||
func (mach *OlmMachine) unwedgeDevice(sender id.UserID, senderKey id.SenderKey) {
|
||||
mach.recentlyUnwedgedLock.Lock()
|
||||
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
|
||||
delta := time.Now().Sub(prevUnwedge)
|
||||
if ok && delta < MinUnwedgeInterval {
|
||||
mach.Log.Debug("Not creating new Olm session with %s/%s, previous recreation was %s ago", sender, deviceID, delta)
|
||||
mach.Log.Debug("Not creating new Olm session with %s/%s, previous recreation was %s ago", sender, senderKey, delta)
|
||||
mach.recentlyUnwedgedLock.Unlock()
|
||||
return
|
||||
}
|
||||
mach.recentlyUnwedged[senderKey] = time.Now()
|
||||
mach.recentlyUnwedgedLock.Unlock()
|
||||
mach.Log.Debug("Creating new Olm session with %s/%s...", sender, deviceID)
|
||||
mach.devicesToUnwedge.Store(senderKey, true)
|
||||
err := mach.SendEncryptedToDevice(&DeviceIdentity{
|
||||
UserID: sender,
|
||||
DeviceID: deviceID,
|
||||
IdentityKey: senderKey,
|
||||
}, event.ToDeviceDummy, event.Content{})
|
||||
|
||||
deviceIdentity, err := mach.GetOrFetchDeviceByKey(sender, senderKey)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to send dummy event to unwedge session with %s/%s: %v", sender, deviceID, err)
|
||||
mach.Log.Error("Failed to find device info by identity key: %v", err)
|
||||
return
|
||||
} else if deviceIdentity == nil {
|
||||
mach.Log.Warn("Didn't find identity of %s/%s, can't unwedge session", sender, senderKey)
|
||||
return
|
||||
}
|
||||
|
||||
mach.Log.Debug("Creating new Olm session with %s/%s (key: %s)", sender, deviceIdentity.DeviceID, senderKey)
|
||||
mach.devicesToUnwedge.Store(senderKey, true)
|
||||
err = mach.SendEncryptedToDevice(deviceIdentity, event.ToDeviceDummy, event.Content{})
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to send dummy event to unwedge session with %s/%s: %v", sender, senderKey, err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -333,6 +333,24 @@ func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID)
|
|||
return nil, fmt.Errorf("didn't get any devices for %s", userID)
|
||||
}
|
||||
|
||||
// GetOrFetchDeviceByKey attempts to retrieve the device identity for the device with the given identity key from the
|
||||
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
|
||||
// the given identity key.
|
||||
func (mach *OlmMachine) GetOrFetchDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*DeviceIdentity, error) {
|
||||
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
|
||||
if err != nil || deviceIdentity != nil {
|
||||
return deviceIdentity, err
|
||||
}
|
||||
mach.Log.Debug("Didn't find identity of %s/%s in crypto store, fetching from server", userID, identityKey)
|
||||
devices := mach.LoadDevices(userID)
|
||||
for _, device := range devices {
|
||||
if device.IdentityKey == identityKey {
|
||||
return device, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// SendEncryptedToDevice sends an Olm-encrypted event to the given user device.
|
||||
func (mach *OlmMachine) SendEncryptedToDevice(device *DeviceIdentity, evtType event.Type, content event.Content) error {
|
||||
if err := mach.createOutboundSessions(map[id.UserID]map[id.DeviceID]*DeviceIdentity{
|
||||
|
|
|
|||
|
|
@ -506,6 +506,25 @@ func (store *SQLCryptoStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (
|
|||
return &identity, nil
|
||||
}
|
||||
|
||||
// FindDeviceByKey finds a specific device by its sender key.
|
||||
func (store *SQLCryptoStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*DeviceIdentity, error) {
|
||||
var identity DeviceIdentity
|
||||
err := store.DB.QueryRow(`
|
||||
SELECT device_id, identity_key, signing_key, trust, deleted, name
|
||||
FROM crypto_device WHERE user_id=$1 AND identity_key=$2`,
|
||||
userID, identityKey,
|
||||
).Scan(&identity.DeviceID, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
identity.UserID = userID
|
||||
identity.IdentityKey = identityKey
|
||||
return &identity, nil
|
||||
}
|
||||
|
||||
// PutDevice stores a single device for a user, replacing it if it exists already.
|
||||
func (store *SQLCryptoStore) PutDevice(userID id.UserID, device *DeviceIdentity) error {
|
||||
var err error
|
||||
|
|
|
|||
|
|
@ -138,6 +138,8 @@ type Store interface {
|
|||
PutDevice(id.UserID, *DeviceIdentity) error
|
||||
// PutDevices overrides the stored device list for the given user with the given list.
|
||||
PutDevices(id.UserID, map[id.DeviceID]*DeviceIdentity) error
|
||||
// FindDeviceByKey finds a specific device by its identity key.
|
||||
FindDeviceByKey(id.UserID, id.IdentityKey) (*DeviceIdentity, error)
|
||||
// FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists
|
||||
// have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty.
|
||||
FilterTrackedUsers([]id.UserID) []id.UserID
|
||||
|
|
@ -471,6 +473,21 @@ func (gs *GobStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*DeviceId
|
|||
return device, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*DeviceIdentity, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
for _, device := range devices {
|
||||
if device.IdentityKey == identityKey {
|
||||
return device, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) PutDevice(userID id.UserID, device *DeviceIdentity) error {
|
||||
gs.lock.Lock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@ type EncryptionEventContent struct {
|
|||
type EncryptedEventContent struct {
|
||||
Algorithm id.Algorithm `json:"algorithm"`
|
||||
SenderKey id.SenderKey `json:"sender_key"`
|
||||
DeviceID id.DeviceID `json:"device_id,omitempty"`
|
||||
SessionID id.SessionID `json:"session_id,omitempty"`
|
||||
DeviceID id.DeviceID `json:"device_id,omitempty"` // Only present for Megolm events
|
||||
SessionID id.SessionID `json:"session_id,omitempty"` // Only present for Megolm events
|
||||
Ciphertext json.RawMessage `json:"ciphertext"`
|
||||
|
||||
MegolmCiphertext []byte `json:"-"`
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue