diff --git a/bridge/bridgeconfig/config.go b/bridge/bridgeconfig/config.go index ed0eb0d2..38b3c89f 100644 --- a/bridge/bridgeconfig/config.go +++ b/bridge/bridgeconfig/config.go @@ -175,6 +175,11 @@ type EncryptionConfig struct { PlaintextMentions bool `yaml:"plaintext_mentions"` + DeleteKeys struct { + DeleteOutboundOnAck bool `yaml:"delete_outbound_on_ack"` + RatchetOnDecrypt bool `yaml:"ratchet_on_decrypt"` + } `yaml:"delete_keys"` + VerificationLevels struct { Receive id.TrustState `yaml:"receive"` Send id.TrustState `yaml:"send"` diff --git a/bridge/crypto.go b/bridge/crypto.go index 7dc42c1b..400ad65d 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -90,8 +90,12 @@ func (helper *CryptoHelper) Init() error { stateStore := &cryptoStateStore{helper.bridge} helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, stateStore) helper.mach.AllowKeyShare = helper.allowKeyShare - helper.mach.SendKeysMinTrust = helper.bridge.Config.Bridge.GetEncryptionConfig().VerificationLevels.Receive - helper.mach.PlaintextMentions = helper.bridge.Config.Bridge.GetEncryptionConfig().PlaintextMentions + + encryptionConfig := helper.bridge.Config.Bridge.GetEncryptionConfig() + helper.mach.SendKeysMinTrust = encryptionConfig.VerificationLevels.Receive + helper.mach.PlaintextMentions = encryptionConfig.PlaintextMentions + helper.mach.DeleteOutboundKeysOnAck = encryptionConfig.DeleteKeys.DeleteOutboundOnAck + helper.mach.RatchetKeysOnDecrypt = encryptionConfig.DeleteKeys.RatchetOnDecrypt helper.client.Syncer = &cryptoSyncer{helper.mach} helper.client.Store = helper.store diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 2f2ed13b..0d28be48 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -13,6 +13,8 @@ import ( "fmt" "strings" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -24,6 +26,7 @@ var ( WrongRoom = errors.New("encrypted megolm event is not intended for this room") DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") + RatchetError = errors.New("failed to ratchet session after use") ) type megolmEvent struct { @@ -55,21 +58,9 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if origRoomID, ok := evt.Content.Raw["com.beeper.original_room_id"].(string); ok && strings.HasSuffix(origRoomID, ".local") && strings.HasSuffix(evt.RoomID.String(), ".local") { encryptionRoomID = id.RoomID(origRoomID) } - sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID) + sess, plaintext, messageIndex, err := mach.actuallyDecryptMegolmEvent(ctx, evt, encryptionRoomID, content) if err != nil { - return nil, fmt.Errorf("failed to get group session: %w", err) - } else if sess == nil { - return nil, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) - } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { - return nil, SenderKeyMismatch - } - plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) - if err != nil { - return nil, fmt.Errorf("failed to decrypt megolm event: %w", err) - } else if ok, err = mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { - return nil, fmt.Errorf("failed to check if message index is duplicate: %w", err) - } else if !ok { - return nil, DuplicateMessageIndex + return nil, err } log = log.With().Uint("message_index", messageIndex).Logger() @@ -160,3 +151,81 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event }, }, nil } + +func removeItem(slice []uint, item uint) ([]uint, bool) { + for i, s := range slice { + if s == item { + return append(slice[:i], slice[i+1:]...), true + } + } + return slice, false +} + +func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) { + mach.megolmDecryptLock.Lock() + defer mach.megolmDecryptLock.Unlock() + + sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID) + if err != nil { + return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) + } else if sess == nil { + return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) + } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { + return sess, nil, 0, SenderKeyMismatch + } + plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) + if err != nil { + return sess, nil, 0, fmt.Errorf("failed to decrypt megolm event: %w", err) + } else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { + return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err) + } else if !ok { + return sess, nil, messageIndex, DuplicateMessageIndex + } + + expectedMessageIndex := sess.RatchetSafety.NextIndex + didModify := false + switch { + case messageIndex > expectedMessageIndex: + // When the index jumps, add indices in between to the missed indices list. + for i := expectedMessageIndex; i < messageIndex; i++ { + sess.RatchetSafety.MissedIndices = append(sess.RatchetSafety.MissedIndices, i) + } + fallthrough + case messageIndex == expectedMessageIndex: + // When the index moves forward (to the next one or jumping ahead), update the last received index. + sess.RatchetSafety.NextIndex = messageIndex + 1 + didModify = true + default: + sess.RatchetSafety.MissedIndices, didModify = removeItem(sess.RatchetSafety.MissedIndices, messageIndex) + } + ratchetTargetIndex := uint32(sess.RatchetSafety.NextIndex) + if len(sess.RatchetSafety.MissedIndices) > 0 { + ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0]) + } + ratchetCurrentIndex := sess.Internal.FirstKnownIndex() + log := zerolog.Ctx(ctx).With(). + Uint32("prev_ratchet_index", ratchetCurrentIndex). + Uint32("new_ratchet_index", ratchetTargetIndex). + Uint("next_new_index", sess.RatchetSafety.NextIndex). + Uints("missed_indices", sess.RatchetSafety.MissedIndices). + Logger() + if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt { + if err = sess.RatchetTo(ratchetTargetIndex); err != nil { + log.Err(err).Msg("Failed to ratchet session") + return sess, plaintext, messageIndex, RatchetError + } else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { + log.Err(err).Msg("Failed to store ratcheted session") + return sess, plaintext, messageIndex, RatchetError + } else { + log.Debug().Msg("Ratcheted session forward") + } + } else if didModify { + if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { + log.Err(err).Msg("Failed to store updated ratchet safety data") + return sess, plaintext, messageIndex, RatchetError + } + } else { + log.Debug().Msg("Ratchet safety data didn't change") + } + return sess, plaintext, messageIndex, nil +} diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index a9ec9a6f..0bf30153 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -82,6 +82,8 @@ func parseMessageIndex(ciphertext []byte) (uint64, error) { // If you use the event.Content struct, make sure you pass a pointer to the struct, // as JSON serialization will not work correctly otherwise. func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) { + mach.megolmEncryptLock.Lock() + defer mach.megolmEncryptLock.Unlock() session, err := mach.CryptoStore.GetOutboundGroupSession(roomID) if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) @@ -136,7 +138,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession { session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID)) signingKey, idKey := mach.account.Keys() - mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key()) + mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages) return session } @@ -158,6 +160,8 @@ func strishArray[T ~string](arr []T) []string { // For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent. // If AllowUnverifiedDevices is false, a similar event with code=m.unverified is sent to devices with TrustStateUnset func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error { + mach.megolmEncryptLock.Lock() + defer mach.megolmEncryptLock.Unlock() session, err := mach.CryptoStore.GetOutboundGroupSession(roomID) if err != nil { return fmt.Errorf("failed to get previous outbound group session: %w", err) diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 802f2907..ed66f23b 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -17,6 +17,7 @@ import ( "encoding/json" "errors" "fmt" + "time" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" @@ -108,6 +109,8 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er RoomID: session.RoomID, // TODO should we add something here to mark the signing key as unverified like key requests do? ForwardingChains: session.ForwardingChains, + + ReceivedAt: time.Now().UTC(), } existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { diff --git a/crypto/keysharing.go b/crypto/keysharing.go index b4b8dceb..4472316b 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -9,6 +9,7 @@ package crypto import ( "context" + "time" "github.com/rs/zerolog" @@ -150,6 +151,13 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt Msg("Mismatched session ID while creating inbound group session from forward") return false } + config := mach.StateStore.GetEncryptionEvent(content.RoomID) + var maxAge time.Duration + var maxMessages int + if config != nil { + maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond + maxMessages = config.RotationPeriodMessages + } igs := &InboundGroupSession{ Internal: *igsInternal, SigningKey: evt.Keys.Ed25519, @@ -157,6 +165,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt RoomID: content.RoomID, ForwardingChains: append(content.ForwardingKeyChain, evt.SenderKey.String()), id: content.SessionID, + + ReceivedAt: time.Now().UTC(), + MaxAge: maxAge.Milliseconds(), + MaxMessages: maxMessages, } err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs) if err != nil { @@ -298,3 +310,28 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User log.Debug().Msg("Successfully sent forwarded group session") } } + +func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.UserID, content *event.BeeperRoomKeyAckEventContent) { + log := mach.machOrContextLog(ctx).With(). + Str("room_id", content.RoomID.String()). + Str("session_id", content.SessionID.String()). + Int("first_message_index", content.FirstMessageIndex). + Logger() + + sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID) + if err != nil { + log.Err(err).Msg("Failed to get group session to check if it should be redacted") + return + } + + isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey + if isInbound && mach.DeleteOutboundKeysOnAck { + log.Debug().Msg("Redacting inbound copy of outbound group session after ack") + err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID) + if err != nil { + log.Err(err).Msg("Failed to redact group session") + } + } else { + log.Debug().Bool("inbound", isInbound).Msg("Received room key ack") + } +} diff --git a/crypto/machine.go b/crypto/machine.go index bf14084f..5cbe7744 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -55,7 +55,9 @@ type OlmMachine struct { recentlyUnwedged map[id.IdentityKey]time.Time recentlyUnwedgedLock sync.Mutex - olmLock sync.Mutex + olmLock sync.Mutex + megolmEncryptLock sync.Mutex + megolmDecryptLock sync.Mutex otkUploadLock sync.Mutex lastOTKUpload time.Time @@ -64,6 +66,9 @@ type OlmMachine struct { crossSigningPubkeys *CrossSigningPublicKeysCache crossSigningPubkeysFetched bool + + DeleteOutboundKeysOnAck bool + RatchetKeysOnDecrypt bool } // StateStore is used by OlmMachine to get room state information that's needed for encryption. @@ -365,6 +370,8 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { return case *event.RoomKeyRequestEventContent: go mach.handleRoomKeyRequest(ctx, evt.Sender, content) + case *event.BeeperRoomKeyAckEventContent: + mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content) // verification cases case *event.VerificationStartEventContent: mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "") @@ -472,9 +479,9 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De return err } -func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string) { +func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int) { log := zerolog.Ctx(ctx) - igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey) + igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey, maxAge, maxMessages) if err != nil { log.Error().Err(err).Msg("Failed to create inbound group session") return @@ -539,7 +546,14 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve return } - mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey) + config := mach.StateStore.GetEncryptionEvent(content.RoomID) + var maxAge time.Duration + var maxMessages int + if config != nil { + maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond + maxMessages = config.RotationPeriodMessages + } + mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages) } func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) { diff --git a/crypto/machine_test.go b/crypto/machine_test.go index 11eccb2e..5d1b3636 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -121,7 +121,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // store room key in new inbound group session decrypted.Content.ParseRaw(event.ToDeviceRoomKey) roomKeyEvt := decrypted.Content.AsRoomKey() - igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey) + igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0) if err != nil { t.Errorf("Error creating inbound megolm session: %v", err) } diff --git a/crypto/sessions.go b/crypto/sessions.go index 23e4d1d4..355a4660 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -89,6 +89,11 @@ func (session *OlmSession) Decrypt(ciphertext string, msgType id.OlmMsgType) ([] return msg, err } +type RatchetSafety struct { + NextIndex uint `json:"next_index"` + MissedIndices []uint `json:"missed_indices,omitempty"` +} + type InboundGroupSession struct { Internal olm.InboundGroupSession @@ -97,11 +102,16 @@ type InboundGroupSession struct { RoomID id.RoomID ForwardingChains []string + RatchetSafety RatchetSafety + + ReceivedAt time.Time + MaxAge int64 + MaxMessages int id id.SessionID } -func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string) (*InboundGroupSession, error) { +func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string, maxAge time.Duration, maxMessages int) (*InboundGroupSession, error) { igs, err := olm.NewInboundGroupSession([]byte(sessionKey)) if err != nil { return nil, err @@ -112,6 +122,9 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI SenderKey: senderKey, RoomID: roomID, ForwardingChains: nil, + ReceivedAt: time.Now().UTC(), + MaxAge: maxAge.Milliseconds(), + MaxMessages: maxMessages, }, nil } diff --git a/crypto/sql_store.go b/crypto/sql_store.go index f34bb86d..3d099cfc 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -10,6 +10,7 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/json" "errors" "fmt" "strings" @@ -254,31 +255,58 @@ func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession) return err } +func intishPtr[T int | int64](i T) *T { + if i == 0 { + return nil + } + return &i +} + +func datePtr(t time.Time) *time.Time { + if t.IsZero() { + return nil + } + return &t +} + // PutGroupSession stores an inbound Megolm group session for a room, sender and session. func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) forwardingChains := strings.Join(session.ForwardingChains, ",") - _, err := store.DB.Exec(` - INSERT INTO crypto_megolm_inbound_session - (session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id) - VALUES ($1, $2, $3, $4, $5, $6, $7) + ratchetSafety, err := json.Marshal(&session.RatchetSafety) + if err != nil { + return fmt.Errorf("failed to marshal ratchet safety info: %w", err) + } + _, err = store.DB.Exec(` + INSERT INTO crypto_megolm_inbound_session ( + session_id, sender_key, signing_key, room_id, session, forwarding_chains, + ratchet_safety, received_at, max_age, max_messages, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, - room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains - `, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, store.AccountID) + room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, + ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, + max_age=excluded.max_age, max_messages=excluded.max_messages + `, + sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, + ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages), + store.AccountID, + ) return err } // GetGroupSession retrieves an inbound Megolm group session for a room, sender and session. func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { var signingKey, forwardingChains, withheldCode sql.NullString - var sessionBytes []byte + var sessionBytes, ratchetSafetyBytes []byte + var receivedAt sql.NullTime + var maxAge, maxMessages sql.NullInt64 err := store.DB.QueryRow(` - SELECT signing_key, session, forwarding_chains, withheld_code + SELECT signing_key, session, forwarding_chains, withheld_code, ratchet_safety, received_at, max_age, max_messages FROM crypto_megolm_inbound_session - WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`, + WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`, roomID, senderKey, sessionID, store.AccountID, - ).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode) + ).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages) if err == sql.ErrNoRows { return nil, nil } else if err != nil { @@ -295,18 +323,38 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send if forwardingChains.String != "" { chains = strings.Split(forwardingChains.String, ",") } + var rs RatchetSafety + if len(ratchetSafetyBytes) > 0 { + err = json.Unmarshal(ratchetSafetyBytes, &rs) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err) + } + } return &InboundGroupSession{ Internal: *igs, SigningKey: id.Ed25519(signingKey.String), SenderKey: senderKey, RoomID: roomID, ForwardingChains: chains, + RatchetSafety: rs, + ReceivedAt: receivedAt.Time, + MaxAge: maxAge.Int64, + MaxMessages: int(maxMessages.Int64), }, nil } +func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID) error { + _, err := store.DB.Exec(` + UPDATE crypto_megolm_inbound_session + SET withheld_code=$1, withheld_reason='Session redacted', session=NULL, forwarding_chains=NULL, ratchet_safety=NULL + WHERE session_id=$2 AND account_id=$3 + `, event.RoomKeyWithheldBeeperRedacted, sessionID, store.AccountID) + return err +} + func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error { - _, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, account_id) VALUES ($1, $2, $3, $4, $5, $6)", - content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, store.AccountID) + _, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6. $7)", + content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID) return err } @@ -336,8 +384,10 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I for rows.Next() { var roomID id.RoomID var signingKey, senderKey, forwardingChains sql.NullString - var sessionBytes []byte - err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains) + var sessionBytes, ratchetSafetyBytes []byte + var receivedAt sql.NullTime + var maxAge, maxMessages sql.NullInt64 + err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages) if err != nil { return } @@ -350,12 +400,23 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I if forwardingChains.String != "" { chains = strings.Split(forwardingChains.String, ",") } + var rs RatchetSafety + if len(ratchetSafetyBytes) > 0 { + err = json.Unmarshal(ratchetSafetyBytes, &rs) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err) + } + } result = append(result, &InboundGroupSession{ Internal: *igs, SigningKey: id.Ed25519(signingKey.String), SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, ForwardingChains: chains, + RatchetSafety: rs, + ReceivedAt: receivedAt.Time, + MaxAge: maxAge.Int64, + MaxMessages: int(maxMessages.Int64), }) } return @@ -363,7 +424,7 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) { rows, err := store.DB.Query(` - SELECT room_id, signing_key, sender_key, session, forwarding_chains + SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) @@ -377,7 +438,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*Inbou func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) { rows, err := store.DB.Query(` - SELECT room_id, signing_key, sender_key, session, forwarding_chains + SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`, store.AccountID, ) diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 44bdb5a1..7a3633f6 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v8: Latest revision +-- v0 -> v10: Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -52,6 +52,10 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( forwarding_chains bytea, withheld_code TEXT, withheld_reason TEXT, + ratchet_safety jsonb, + received_at timestamp, + max_age BIGINT, + max_messages INTEGER, PRIMARY KEY (account_id, session_id) ); diff --git a/crypto/sql_store_upgrade/10-mark-ratchetable-keys.sql b/crypto/sql_store_upgrade/10-mark-ratchetable-keys.sql new file mode 100644 index 00000000..36a7b25a --- /dev/null +++ b/crypto/sql_store_upgrade/10-mark-ratchetable-keys.sql @@ -0,0 +1,5 @@ +-- v10: Add flag for megolm sessions to mark them as safe to delete +ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb; +ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp; +ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT; +ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER; diff --git a/crypto/store.go b/crypto/store.go index 56881041..9253824a 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -54,6 +54,8 @@ type Store interface { // (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the // ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details. GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error) + // RedactGroupSession removes the session data for the given inbound Megolm session from the store. + RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID) error // PutWithheldGroupSession tells the store that a specific Megolm session was withheld. PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error // GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession. @@ -262,6 +264,14 @@ func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, return session, nil } +func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) error { + gs.lock.Lock() + delete(gs.getGroupSessions(roomID, senderKey), sessionID) + err := gs.save() + gs.lock.Unlock() + return err +} + func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent { room, ok := gs.WithheldGroupSessions[roomID] if !ok { diff --git a/event/beeper.go b/event/beeper.go index 2ee72073..926b5d07 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -49,3 +49,9 @@ type BeeperRetryMetadata struct { RetryCount int `json:"retry_count"` // last_retry is also present, but not used by bridges } + +type BeeperRoomKeyAckEventContent struct { + RoomID id.RoomID `json:"room_id"` + SessionID id.SessionID `json:"session_id"` + FirstMessageIndex int `json:"first_message_index"` +} diff --git a/event/content.go b/event/content.go index 5624fd59..24c1c193 100644 --- a/event/content.go +++ b/event/content.go @@ -80,6 +80,8 @@ var TypeMap = map[Type]reflect.Type{ ToDeviceOrgMatrixRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}), + ToDeviceBeeperRoomKeyAck: reflect.TypeOf(BeeperRoomKeyAckEventContent{}), + CallInvite: reflect.TypeOf(CallInviteEventContent{}), CallCandidates: reflect.TypeOf(CallCandidatesEventContent{}), CallAnswer: reflect.TypeOf(CallAnswerEventContent{}), diff --git a/event/encryption.go b/event/encryption.go index 2506ad57..ab4290b6 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -135,6 +135,8 @@ const ( RoomKeyWithheldUnauthorized RoomKeyWithheldCode = "m.unauthorised" RoomKeyWithheldUnavailable RoomKeyWithheldCode = "m.unavailable" RoomKeyWithheldNoOlmSession RoomKeyWithheldCode = "m.no_olm" + + RoomKeyWithheldBeeperRedacted RoomKeyWithheldCode = "com.beeper.redacted" ) type RoomKeyWithheldEventContent struct { diff --git a/event/type.go b/event/type.go index 050b17e9..9ac64b6d 100644 --- a/event/type.go +++ b/event/type.go @@ -124,7 +124,8 @@ func (et *Type) GuessClass() TypeClass { CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type: return MessageEventType - case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type: + case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, + ToDeviceBeeperRoomKeyAck.Type: return ToDeviceEventType default: return UnknownEventType @@ -253,4 +254,6 @@ var ( ToDeviceVerificationCancel = Type{"m.key.verification.cancel", ToDeviceEventType} ToDeviceOrgMatrixRoomKeyWithheld = Type{"org.matrix.room_key.withheld", ToDeviceEventType} + + ToDeviceBeeperRoomKeyAck = Type{"com.beeper.room_key.ack", ToDeviceEventType} )