Ratchet inbound sessions on decrypt and delete outbound on ack

This commit is contained in:
Tulir Asokan 2023-04-10 16:18:03 +03:00
commit 20df20d25a
17 changed files with 283 additions and 41 deletions

View file

@ -175,6 +175,11 @@ type EncryptionConfig struct {
PlaintextMentions bool `yaml:"plaintext_mentions"`
DeleteKeys struct {
DeleteOutboundOnAck bool `yaml:"delete_outbound_on_ack"`
RatchetOnDecrypt bool `yaml:"ratchet_on_decrypt"`
} `yaml:"delete_keys"`
VerificationLevels struct {
Receive id.TrustState `yaml:"receive"`
Send id.TrustState `yaml:"send"`

View file

@ -90,8 +90,12 @@ func (helper *CryptoHelper) Init() error {
stateStore := &cryptoStateStore{helper.bridge}
helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, stateStore)
helper.mach.AllowKeyShare = helper.allowKeyShare
helper.mach.SendKeysMinTrust = helper.bridge.Config.Bridge.GetEncryptionConfig().VerificationLevels.Receive
helper.mach.PlaintextMentions = helper.bridge.Config.Bridge.GetEncryptionConfig().PlaintextMentions
encryptionConfig := helper.bridge.Config.Bridge.GetEncryptionConfig()
helper.mach.SendKeysMinTrust = encryptionConfig.VerificationLevels.Receive
helper.mach.PlaintextMentions = encryptionConfig.PlaintextMentions
helper.mach.DeleteOutboundKeysOnAck = encryptionConfig.DeleteKeys.DeleteOutboundOnAck
helper.mach.RatchetKeysOnDecrypt = encryptionConfig.DeleteKeys.RatchetOnDecrypt
helper.client.Syncer = &cryptoSyncer{helper.mach}
helper.client.Store = helper.store

View file

@ -13,6 +13,8 @@ import (
"fmt"
"strings"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@ -24,6 +26,7 @@ var (
WrongRoom = errors.New("encrypted megolm event is not intended for this room")
DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
RatchetError = errors.New("failed to ratchet session after use")
)
type megolmEvent struct {
@ -55,21 +58,9 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
if origRoomID, ok := evt.Content.Raw["com.beeper.original_room_id"].(string); ok && strings.HasSuffix(origRoomID, ".local") && strings.HasSuffix(evt.RoomID.String(), ".local") {
encryptionRoomID = id.RoomID(origRoomID)
}
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
sess, plaintext, messageIndex, err := mach.actuallyDecryptMegolmEvent(ctx, evt, encryptionRoomID, content)
if err != nil {
return nil, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
return nil, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
return nil, SenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
return nil, fmt.Errorf("failed to decrypt megolm event: %w", err)
} else if ok, err = mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return nil, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
return nil, DuplicateMessageIndex
return nil, err
}
log = log.With().Uint("message_index", messageIndex).Logger()
@ -160,3 +151,81 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
},
}, nil
}
func removeItem(slice []uint, item uint) ([]uint, bool) {
for i, s := range slice {
if s == item {
return append(slice[:i], slice[i+1:]...), true
}
}
return slice, false
}
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
mach.megolmDecryptLock.Lock()
defer mach.megolmDecryptLock.Unlock()
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
return sess, nil, 0, SenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
return sess, nil, 0, fmt.Errorf("failed to decrypt megolm event: %w", err)
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
return sess, nil, messageIndex, DuplicateMessageIndex
}
expectedMessageIndex := sess.RatchetSafety.NextIndex
didModify := false
switch {
case messageIndex > expectedMessageIndex:
// When the index jumps, add indices in between to the missed indices list.
for i := expectedMessageIndex; i < messageIndex; i++ {
sess.RatchetSafety.MissedIndices = append(sess.RatchetSafety.MissedIndices, i)
}
fallthrough
case messageIndex == expectedMessageIndex:
// When the index moves forward (to the next one or jumping ahead), update the last received index.
sess.RatchetSafety.NextIndex = messageIndex + 1
didModify = true
default:
sess.RatchetSafety.MissedIndices, didModify = removeItem(sess.RatchetSafety.MissedIndices, messageIndex)
}
ratchetTargetIndex := uint32(sess.RatchetSafety.NextIndex)
if len(sess.RatchetSafety.MissedIndices) > 0 {
ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0])
}
ratchetCurrentIndex := sess.Internal.FirstKnownIndex()
log := zerolog.Ctx(ctx).With().
Uint32("prev_ratchet_index", ratchetCurrentIndex).
Uint32("new_ratchet_index", ratchetTargetIndex).
Uint("next_new_index", sess.RatchetSafety.NextIndex).
Uints("missed_indices", sess.RatchetSafety.MissedIndices).
Logger()
if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
return sess, plaintext, messageIndex, RatchetError
} else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
return sess, plaintext, messageIndex, RatchetError
} else {
log.Debug().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
return sess, plaintext, messageIndex, RatchetError
}
} else {
log.Debug().Msg("Ratchet safety data didn't change")
}
return sess, plaintext, messageIndex, nil
}

View file

@ -82,6 +82,8 @@ func parseMessageIndex(ciphertext []byte) (uint64, error) {
// If you use the event.Content struct, make sure you pass a pointer to the struct,
// as JSON serialization will not work correctly otherwise.
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
@ -136,7 +138,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession {
session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID))
signingKey, idKey := mach.account.Keys()
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key())
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages)
return session
}
@ -158,6 +160,8 @@ func strishArray[T ~string](arr []T) []string {
// For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent.
// If AllowUnverifiedDevices is false, a similar event with code=m.unverified is sent to devices with TrustStateUnset
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
if err != nil {
return fmt.Errorf("failed to get previous outbound group session: %w", err)

View file

@ -17,6 +17,7 @@ import (
"encoding/json"
"errors"
"fmt"
"time"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
@ -108,6 +109,8 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
RoomID: session.RoomID,
// TODO should we add something here to mark the signing key as unverified like key requests do?
ForwardingChains: session.ForwardingChains,
ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {

View file

@ -9,6 +9,7 @@ package crypto
import (
"context"
"time"
"github.com/rs/zerolog"
@ -150,6 +151,13 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
Msg("Mismatched session ID while creating inbound group session from forward")
return false
}
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
var maxAge time.Duration
var maxMessages int
if config != nil {
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
maxMessages = config.RotationPeriodMessages
}
igs := &InboundGroupSession{
Internal: *igsInternal,
SigningKey: evt.Keys.Ed25519,
@ -157,6 +165,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
RoomID: content.RoomID,
ForwardingChains: append(content.ForwardingKeyChain, evt.SenderKey.String()),
id: content.SessionID,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
}
err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs)
if err != nil {
@ -298,3 +310,28 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User
log.Debug().Msg("Successfully sent forwarded group session")
}
}
func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.UserID, content *event.BeeperRoomKeyAckEventContent) {
log := mach.machOrContextLog(ctx).With().
Str("room_id", content.RoomID.String()).
Str("session_id", content.SessionID.String()).
Int("first_message_index", content.FirstMessageIndex).
Logger()
sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID)
if err != nil {
log.Err(err).Msg("Failed to get group session to check if it should be redacted")
return
}
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
if isInbound && mach.DeleteOutboundKeysOnAck {
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID)
if err != nil {
log.Err(err).Msg("Failed to redact group session")
}
} else {
log.Debug().Bool("inbound", isInbound).Msg("Received room key ack")
}
}

View file

@ -55,7 +55,9 @@ type OlmMachine struct {
recentlyUnwedged map[id.IdentityKey]time.Time
recentlyUnwedgedLock sync.Mutex
olmLock sync.Mutex
olmLock sync.Mutex
megolmEncryptLock sync.Mutex
megolmDecryptLock sync.Mutex
otkUploadLock sync.Mutex
lastOTKUpload time.Time
@ -64,6 +66,9 @@ type OlmMachine struct {
crossSigningPubkeys *CrossSigningPublicKeysCache
crossSigningPubkeysFetched bool
DeleteOutboundKeysOnAck bool
RatchetKeysOnDecrypt bool
}
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
@ -365,6 +370,8 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
return
case *event.RoomKeyRequestEventContent:
go mach.handleRoomKeyRequest(ctx, evt.Sender, content)
case *event.BeeperRoomKeyAckEventContent:
mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content)
// verification cases
case *event.VerificationStartEventContent:
mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "")
@ -472,9 +479,9 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De
return err
}
func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string) {
func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int) {
log := zerolog.Ctx(ctx)
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey)
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey, maxAge, maxMessages)
if err != nil {
log.Error().Err(err).Msg("Failed to create inbound group session")
return
@ -539,7 +546,14 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
return
}
mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey)
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
var maxAge time.Duration
var maxMessages int
if config != nil {
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
maxMessages = config.RotationPeriodMessages
}
mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages)
}
func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) {

View file

@ -121,7 +121,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
// store room key in new inbound group session
decrypted.Content.ParseRaw(event.ToDeviceRoomKey)
roomKeyEvt := decrypted.Content.AsRoomKey()
igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey)
igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0)
if err != nil {
t.Errorf("Error creating inbound megolm session: %v", err)
}

View file

@ -89,6 +89,11 @@ func (session *OlmSession) Decrypt(ciphertext string, msgType id.OlmMsgType) ([]
return msg, err
}
type RatchetSafety struct {
NextIndex uint `json:"next_index"`
MissedIndices []uint `json:"missed_indices,omitempty"`
}
type InboundGroupSession struct {
Internal olm.InboundGroupSession
@ -97,11 +102,16 @@ type InboundGroupSession struct {
RoomID id.RoomID
ForwardingChains []string
RatchetSafety RatchetSafety
ReceivedAt time.Time
MaxAge int64
MaxMessages int
id id.SessionID
}
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string) (*InboundGroupSession, error) {
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string, maxAge time.Duration, maxMessages int) (*InboundGroupSession, error) {
igs, err := olm.NewInboundGroupSession([]byte(sessionKey))
if err != nil {
return nil, err
@ -112,6 +122,9 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
SenderKey: senderKey,
RoomID: roomID,
ForwardingChains: nil,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
}, nil
}

View file

@ -10,6 +10,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
@ -254,31 +255,58 @@ func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession)
return err
}
func intishPtr[T int | int64](i T) *T {
if i == 0 {
return nil
}
return &i
}
func datePtr(t time.Time) *time.Time {
if t.IsZero() {
return nil
}
return &t
}
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
forwardingChains := strings.Join(session.ForwardingChains, ",")
_, err := store.DB.Exec(`
INSERT INTO crypto_megolm_inbound_session
(session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
if err != nil {
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
}
_, err = store.DB.Exec(`
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
ratchet_safety, received_at, max_age, max_messages, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains
`, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, store.AccountID)
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
max_age=excluded.max_age, max_messages=excluded.max_messages
`,
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains,
ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages),
store.AccountID,
)
return err
}
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
var signingKey, forwardingChains, withheldCode sql.NullString
var sessionBytes []byte
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
err := store.DB.QueryRow(`
SELECT signing_key, session, forwarding_chains, withheld_code
SELECT signing_key, session, forwarding_chains, withheld_code, ratchet_safety, received_at, max_age, max_messages
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
roomID, senderKey, sessionID, store.AccountID,
).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode)
).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
@ -295,18 +323,38 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
return &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: senderKey,
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
}, nil
}
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID) error {
_, err := store.DB.Exec(`
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason='Session redacted', session=NULL, forwarding_chains=NULL, ratchet_safety=NULL
WHERE session_id=$2 AND account_id=$3
`, event.RoomKeyWithheldBeeperRedacted, sessionID, store.AccountID)
return err
}
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, account_id) VALUES ($1, $2, $3, $4, $5, $6)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, store.AccountID)
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6. $7)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
return err
}
@ -336,8 +384,10 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
for rows.Next() {
var roomID id.RoomID
var signingKey, senderKey, forwardingChains sql.NullString
var sessionBytes []byte
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains)
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages)
if err != nil {
return
}
@ -350,12 +400,23 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
result = append(result, &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: id.Curve25519(senderKey.String),
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
})
}
return
@ -363,7 +424,7 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID,
)
@ -377,7 +438,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*Inbou
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
store.AccountID,
)

View file

@ -1,4 +1,4 @@
-- v0 -> v8: Latest revision
-- v0 -> v10: Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
@ -52,6 +52,10 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
forwarding_chains bytea,
withheld_code TEXT,
withheld_reason TEXT,
ratchet_safety jsonb,
received_at timestamp,
max_age BIGINT,
max_messages INTEGER,
PRIMARY KEY (account_id, session_id)
);

View file

@ -0,0 +1,5 @@
-- v10: Add flag for megolm sessions to mark them as safe to delete
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER;

View file

@ -54,6 +54,8 @@ type Store interface {
// (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the
// ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details.
GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
// RedactGroupSession removes the session data for the given inbound Megolm session from the store.
RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID) error
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
@ -262,6 +264,14 @@ func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey,
return session, nil
}
func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) error {
gs.lock.Lock()
delete(gs.getGroupSessions(roomID, senderKey), sessionID)
err := gs.save()
gs.lock.Unlock()
return err
}
func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent {
room, ok := gs.WithheldGroupSessions[roomID]
if !ok {

View file

@ -49,3 +49,9 @@ type BeeperRetryMetadata struct {
RetryCount int `json:"retry_count"`
// last_retry is also present, but not used by bridges
}
type BeeperRoomKeyAckEventContent struct {
RoomID id.RoomID `json:"room_id"`
SessionID id.SessionID `json:"session_id"`
FirstMessageIndex int `json:"first_message_index"`
}

View file

@ -80,6 +80,8 @@ var TypeMap = map[Type]reflect.Type{
ToDeviceOrgMatrixRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}),
ToDeviceBeeperRoomKeyAck: reflect.TypeOf(BeeperRoomKeyAckEventContent{}),
CallInvite: reflect.TypeOf(CallInviteEventContent{}),
CallCandidates: reflect.TypeOf(CallCandidatesEventContent{}),
CallAnswer: reflect.TypeOf(CallAnswerEventContent{}),

View file

@ -135,6 +135,8 @@ const (
RoomKeyWithheldUnauthorized RoomKeyWithheldCode = "m.unauthorised"
RoomKeyWithheldUnavailable RoomKeyWithheldCode = "m.unavailable"
RoomKeyWithheldNoOlmSession RoomKeyWithheldCode = "m.no_olm"
RoomKeyWithheldBeeperRedacted RoomKeyWithheldCode = "com.beeper.redacted"
)
type RoomKeyWithheldEventContent struct {

View file

@ -124,7 +124,8 @@ func (et *Type) GuessClass() TypeClass {
CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type:
return MessageEventType
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type:
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type,
ToDeviceBeeperRoomKeyAck.Type:
return ToDeviceEventType
default:
return UnknownEventType
@ -253,4 +254,6 @@ var (
ToDeviceVerificationCancel = Type{"m.key.verification.cancel", ToDeviceEventType}
ToDeviceOrgMatrixRoomKeyWithheld = Type{"org.matrix.room_key.withheld", ToDeviceEventType}
ToDeviceBeeperRoomKeyAck = Type{"com.beeper.room_key.ack", ToDeviceEventType}
)