diff --git a/crypto/cross_sign.go b/crypto/cross_sign.go index 0f25c89a..a158af8a 100644 --- a/crypto/cross_sign.go +++ b/crypto/cross_sign.go @@ -15,6 +15,28 @@ import ( func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) { for userID, userKeys := range crossSigningKeys { + currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) + if err != nil { + mach.Log.Error("Error fetching current cross-signing keys of user %v: %v", userID, err) + } + if currentKeys != nil { + for curKeyUsage, curKey := range currentKeys { + // got a new key with the same usage as an existing key + for _, newKeyUsage := range userKeys.Usage { + if newKeyUsage == curKeyUsage { + if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.String())]; !ok { + // old key is not in the new key map so we drop signatures made by it + if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey); err != nil { + mach.Log.Error("Error deleting old signatures: %v", err) + } else { + mach.Log.Debug("Dropped %v signatures made by key `%v` (%v) as it has been replaced", count, curKey, curKeyUsage) + } + } + } + } + } + } + for _, key := range userKeys.Keys { for _, usage := range userKeys.Usage { mach.Log.Debug("Storing cross-signing key for %v: %v (type %v)", userID, key, usage) @@ -36,7 +58,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mau } } - mach.Log.Debug("Verifying %v with: %v %v %v", userKeys, signUserID, signKeyName, signingKey) + mach.Log.Debug("Verifying with key %v of user %v", signingKey, signUserID) if verified, err := olm.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil { mach.Log.Error("Error while verifying cross-signing keys: %v", err) } else { diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 47813764..bb4e45df 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -121,7 +121,6 @@ func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeU mach.Log.Warn("Didn't get any keys for user %s", userID) } - // TODO delete old signatures by previous x-signing keys if they have been updated mach.storeCrossSigningKeys(resp.MasterKeys, resp.DeviceKeys) mach.storeCrossSigningKeys(resp.SelfSigningKeys, resp.DeviceKeys) mach.storeCrossSigningKeys(resp.UserSigningKeys, resp.DeviceKeys) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 9089d3a2..4d276f0a 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -664,3 +664,16 @@ func (store *SQLCryptoStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, sig _, ok := sigs[signerKey] return ok, nil } + +// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. +func (store *SQLCryptoStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) { + res, err := store.DB.Exec("DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key) + if err != nil { + return 0, err + } + count, err := res.RowsAffected() + if err != nil { + return 0, err + } + return count, nil +} diff --git a/crypto/store.go b/crypto/store.go index 7d0b81da..18f81669 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -151,6 +151,8 @@ type Store interface { GetSignaturesForKeyBy(id.UserID, id.Ed25519, id.UserID) (map[id.Ed25519]string, error) // IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer. IsKeySignedBy(id.UserID, id.Ed25519, id.UserID, id.Ed25519) (bool, error) + // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. + DropSignaturesByKey(id.UserID, id.Ed25519) (int64, error) } type messageIndexKey struct { @@ -569,3 +571,20 @@ func (gs *GobStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id. _, ok := sigs[signerKey] return ok, nil } + +func (gs *GobStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) { + var count int64 + gs.lock.RLock() + for _, userSigs := range gs.KeySignatures { + for _, keySigs := range userSigs { + if signedBySigner, ok := keySigs[userID]; ok { + if _, ok := signedBySigner[key]; ok { + count++ + delete(signedBySigner, key) + } + } + } + } + gs.lock.RUnlock() + return count, nil +} diff --git a/crypto/utils/utils.go b/crypto/utils/utils.go index 2a013170..652b2d33 100644 --- a/crypto/utils/utils.go +++ b/crypto/utils/utils.go @@ -47,7 +47,7 @@ func GenAttachmentA256CTR() (key [AESCTRKeyLength]byte, iv [AESCTRIVLength]byte) panic(err) } - // For some reason we leave the 8 last bytes empty even though AES256-CTR has a 16-byte block size. + // The last 8 bytes of the IV act as the counter in AES-CTR, which means they're left empty here _, err = rand.Read(iv[:8]) if err != nil { panic(err)