mirror of
https://mau.dev/mautrix/go.git
synced 2026-03-14 14:25:53 +01:00
Ratchet inbound sessions on decrypt and delete outbound on ack
This commit is contained in:
parent
6e268751db
commit
20df20d25a
17 changed files with 283 additions and 41 deletions
|
|
@ -175,6 +175,11 @@ type EncryptionConfig struct {
|
|||
|
||||
PlaintextMentions bool `yaml:"plaintext_mentions"`
|
||||
|
||||
DeleteKeys struct {
|
||||
DeleteOutboundOnAck bool `yaml:"delete_outbound_on_ack"`
|
||||
RatchetOnDecrypt bool `yaml:"ratchet_on_decrypt"`
|
||||
} `yaml:"delete_keys"`
|
||||
|
||||
VerificationLevels struct {
|
||||
Receive id.TrustState `yaml:"receive"`
|
||||
Send id.TrustState `yaml:"send"`
|
||||
|
|
|
|||
|
|
@ -90,8 +90,12 @@ func (helper *CryptoHelper) Init() error {
|
|||
stateStore := &cryptoStateStore{helper.bridge}
|
||||
helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, stateStore)
|
||||
helper.mach.AllowKeyShare = helper.allowKeyShare
|
||||
helper.mach.SendKeysMinTrust = helper.bridge.Config.Bridge.GetEncryptionConfig().VerificationLevels.Receive
|
||||
helper.mach.PlaintextMentions = helper.bridge.Config.Bridge.GetEncryptionConfig().PlaintextMentions
|
||||
|
||||
encryptionConfig := helper.bridge.Config.Bridge.GetEncryptionConfig()
|
||||
helper.mach.SendKeysMinTrust = encryptionConfig.VerificationLevels.Receive
|
||||
helper.mach.PlaintextMentions = encryptionConfig.PlaintextMentions
|
||||
helper.mach.DeleteOutboundKeysOnAck = encryptionConfig.DeleteKeys.DeleteOutboundOnAck
|
||||
helper.mach.RatchetKeysOnDecrypt = encryptionConfig.DeleteKeys.RatchetOnDecrypt
|
||||
|
||||
helper.client.Syncer = &cryptoSyncer{helper.mach}
|
||||
helper.client.Store = helper.store
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
|
@ -24,6 +26,7 @@ var (
|
|||
WrongRoom = errors.New("encrypted megolm event is not intended for this room")
|
||||
DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
|
||||
SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
|
||||
RatchetError = errors.New("failed to ratchet session after use")
|
||||
)
|
||||
|
||||
type megolmEvent struct {
|
||||
|
|
@ -55,21 +58,9 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
|
|||
if origRoomID, ok := evt.Content.Raw["com.beeper.original_room_id"].(string); ok && strings.HasSuffix(origRoomID, ".local") && strings.HasSuffix(evt.RoomID.String(), ".local") {
|
||||
encryptionRoomID = id.RoomID(origRoomID)
|
||||
}
|
||||
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
|
||||
sess, plaintext, messageIndex, err := mach.actuallyDecryptMegolmEvent(ctx, evt, encryptionRoomID, content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get group session: %w", err)
|
||||
} else if sess == nil {
|
||||
return nil, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
|
||||
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
|
||||
return nil, SenderKeyMismatch
|
||||
}
|
||||
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt megolm event: %w", err)
|
||||
} else if ok, err = mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
|
||||
return nil, fmt.Errorf("failed to check if message index is duplicate: %w", err)
|
||||
} else if !ok {
|
||||
return nil, DuplicateMessageIndex
|
||||
return nil, err
|
||||
}
|
||||
log = log.With().Uint("message_index", messageIndex).Logger()
|
||||
|
||||
|
|
@ -160,3 +151,81 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
|
|||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func removeItem(slice []uint, item uint) ([]uint, bool) {
|
||||
for i, s := range slice {
|
||||
if s == item {
|
||||
return append(slice[:i], slice[i+1:]...), true
|
||||
}
|
||||
}
|
||||
return slice, false
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
|
||||
mach.megolmDecryptLock.Lock()
|
||||
defer mach.megolmDecryptLock.Unlock()
|
||||
|
||||
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
|
||||
} else if sess == nil {
|
||||
return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
|
||||
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
|
||||
return sess, nil, 0, SenderKeyMismatch
|
||||
}
|
||||
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
|
||||
if err != nil {
|
||||
return sess, nil, 0, fmt.Errorf("failed to decrypt megolm event: %w", err)
|
||||
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
|
||||
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
|
||||
} else if !ok {
|
||||
return sess, nil, messageIndex, DuplicateMessageIndex
|
||||
}
|
||||
|
||||
expectedMessageIndex := sess.RatchetSafety.NextIndex
|
||||
didModify := false
|
||||
switch {
|
||||
case messageIndex > expectedMessageIndex:
|
||||
// When the index jumps, add indices in between to the missed indices list.
|
||||
for i := expectedMessageIndex; i < messageIndex; i++ {
|
||||
sess.RatchetSafety.MissedIndices = append(sess.RatchetSafety.MissedIndices, i)
|
||||
}
|
||||
fallthrough
|
||||
case messageIndex == expectedMessageIndex:
|
||||
// When the index moves forward (to the next one or jumping ahead), update the last received index.
|
||||
sess.RatchetSafety.NextIndex = messageIndex + 1
|
||||
didModify = true
|
||||
default:
|
||||
sess.RatchetSafety.MissedIndices, didModify = removeItem(sess.RatchetSafety.MissedIndices, messageIndex)
|
||||
}
|
||||
ratchetTargetIndex := uint32(sess.RatchetSafety.NextIndex)
|
||||
if len(sess.RatchetSafety.MissedIndices) > 0 {
|
||||
ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0])
|
||||
}
|
||||
ratchetCurrentIndex := sess.Internal.FirstKnownIndex()
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Uint32("prev_ratchet_index", ratchetCurrentIndex).
|
||||
Uint32("new_ratchet_index", ratchetTargetIndex).
|
||||
Uint("next_new_index", sess.RatchetSafety.NextIndex).
|
||||
Uints("missed_indices", sess.RatchetSafety.MissedIndices).
|
||||
Logger()
|
||||
if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
|
||||
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
|
||||
log.Err(err).Msg("Failed to ratchet session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store ratcheted session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
log.Debug().Msg("Ratcheted session forward")
|
||||
}
|
||||
} else if didModify {
|
||||
if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store updated ratchet safety data")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
}
|
||||
} else {
|
||||
log.Debug().Msg("Ratchet safety data didn't change")
|
||||
}
|
||||
return sess, plaintext, messageIndex, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,6 +82,8 @@ func parseMessageIndex(ciphertext []byte) (uint64, error) {
|
|||
// If you use the event.Content struct, make sure you pass a pointer to the struct,
|
||||
// as JSON serialization will not work correctly otherwise.
|
||||
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
|
||||
mach.megolmEncryptLock.Lock()
|
||||
defer mach.megolmEncryptLock.Unlock()
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
|
||||
|
|
@ -136,7 +138,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
|
|||
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession {
|
||||
session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID))
|
||||
signingKey, idKey := mach.account.Keys()
|
||||
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key())
|
||||
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages)
|
||||
return session
|
||||
}
|
||||
|
||||
|
|
@ -158,6 +160,8 @@ func strishArray[T ~string](arr []T) []string {
|
|||
// For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent.
|
||||
// If AllowUnverifiedDevices is false, a similar event with code=m.unverified is sent to devices with TrustStateUnset
|
||||
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
|
||||
mach.megolmEncryptLock.Lock()
|
||||
defer mach.megolmEncryptLock.Unlock()
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get previous outbound group session: %w", err)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
|
@ -108,6 +109,8 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
|
|||
RoomID: session.RoomID,
|
||||
// TODO should we add something here to mark the signing key as unverified like key requests do?
|
||||
ForwardingChains: session.ForwardingChains,
|
||||
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
}
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID())
|
||||
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ package crypto
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
|
|
@ -150,6 +151,13 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
|||
Msg("Mismatched session ID while creating inbound group session from forward")
|
||||
return false
|
||||
}
|
||||
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
|
||||
var maxAge time.Duration
|
||||
var maxMessages int
|
||||
if config != nil {
|
||||
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
|
||||
maxMessages = config.RotationPeriodMessages
|
||||
}
|
||||
igs := &InboundGroupSession{
|
||||
Internal: *igsInternal,
|
||||
SigningKey: evt.Keys.Ed25519,
|
||||
|
|
@ -157,6 +165,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
|||
RoomID: content.RoomID,
|
||||
ForwardingChains: append(content.ForwardingKeyChain, evt.SenderKey.String()),
|
||||
id: content.SessionID,
|
||||
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
MaxAge: maxAge.Milliseconds(),
|
||||
MaxMessages: maxMessages,
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs)
|
||||
if err != nil {
|
||||
|
|
@ -298,3 +310,28 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User
|
|||
log.Debug().Msg("Successfully sent forwarded group session")
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.UserID, content *event.BeeperRoomKeyAckEventContent) {
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
Str("room_id", content.RoomID.String()).
|
||||
Str("session_id", content.SessionID.String()).
|
||||
Int("first_message_index", content.FirstMessageIndex).
|
||||
Logger()
|
||||
|
||||
sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to get group session to check if it should be redacted")
|
||||
return
|
||||
}
|
||||
|
||||
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
|
||||
if isInbound && mach.DeleteOutboundKeysOnAck {
|
||||
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
|
||||
err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact group session")
|
||||
}
|
||||
} else {
|
||||
log.Debug().Bool("inbound", isInbound).Msg("Received room key ack")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,7 +55,9 @@ type OlmMachine struct {
|
|||
recentlyUnwedged map[id.IdentityKey]time.Time
|
||||
recentlyUnwedgedLock sync.Mutex
|
||||
|
||||
olmLock sync.Mutex
|
||||
olmLock sync.Mutex
|
||||
megolmEncryptLock sync.Mutex
|
||||
megolmDecryptLock sync.Mutex
|
||||
|
||||
otkUploadLock sync.Mutex
|
||||
lastOTKUpload time.Time
|
||||
|
|
@ -64,6 +66,9 @@ type OlmMachine struct {
|
|||
crossSigningPubkeys *CrossSigningPublicKeysCache
|
||||
|
||||
crossSigningPubkeysFetched bool
|
||||
|
||||
DeleteOutboundKeysOnAck bool
|
||||
RatchetKeysOnDecrypt bool
|
||||
}
|
||||
|
||||
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
|
||||
|
|
@ -365,6 +370,8 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
|
|||
return
|
||||
case *event.RoomKeyRequestEventContent:
|
||||
go mach.handleRoomKeyRequest(ctx, evt.Sender, content)
|
||||
case *event.BeeperRoomKeyAckEventContent:
|
||||
mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content)
|
||||
// verification cases
|
||||
case *event.VerificationStartEventContent:
|
||||
mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "")
|
||||
|
|
@ -472,9 +479,9 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De
|
|||
return err
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string) {
|
||||
func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey)
|
||||
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey, maxAge, maxMessages)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create inbound group session")
|
||||
return
|
||||
|
|
@ -539,7 +546,14 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
|
|||
return
|
||||
}
|
||||
|
||||
mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey)
|
||||
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
|
||||
var maxAge time.Duration
|
||||
var maxMessages int
|
||||
if config != nil {
|
||||
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
|
||||
maxMessages = config.RotationPeriodMessages
|
||||
}
|
||||
mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) {
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
|
|||
// store room key in new inbound group session
|
||||
decrypted.Content.ParseRaw(event.ToDeviceRoomKey)
|
||||
roomKeyEvt := decrypted.Content.AsRoomKey()
|
||||
igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey)
|
||||
igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0)
|
||||
if err != nil {
|
||||
t.Errorf("Error creating inbound megolm session: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,6 +89,11 @@ func (session *OlmSession) Decrypt(ciphertext string, msgType id.OlmMsgType) ([]
|
|||
return msg, err
|
||||
}
|
||||
|
||||
type RatchetSafety struct {
|
||||
NextIndex uint `json:"next_index"`
|
||||
MissedIndices []uint `json:"missed_indices,omitempty"`
|
||||
}
|
||||
|
||||
type InboundGroupSession struct {
|
||||
Internal olm.InboundGroupSession
|
||||
|
||||
|
|
@ -97,11 +102,16 @@ type InboundGroupSession struct {
|
|||
RoomID id.RoomID
|
||||
|
||||
ForwardingChains []string
|
||||
RatchetSafety RatchetSafety
|
||||
|
||||
ReceivedAt time.Time
|
||||
MaxAge int64
|
||||
MaxMessages int
|
||||
|
||||
id id.SessionID
|
||||
}
|
||||
|
||||
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string) (*InboundGroupSession, error) {
|
||||
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string, maxAge time.Duration, maxMessages int) (*InboundGroupSession, error) {
|
||||
igs, err := olm.NewInboundGroupSession([]byte(sessionKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -112,6 +122,9 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
|
|||
SenderKey: senderKey,
|
||||
RoomID: roomID,
|
||||
ForwardingChains: nil,
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
MaxAge: maxAge.Milliseconds(),
|
||||
MaxMessages: maxMessages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
|
@ -254,31 +255,58 @@ func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession)
|
|||
return err
|
||||
}
|
||||
|
||||
func intishPtr[T int | int64](i T) *T {
|
||||
if i == 0 {
|
||||
return nil
|
||||
}
|
||||
return &i
|
||||
}
|
||||
|
||||
func datePtr(t time.Time) *time.Time {
|
||||
if t.IsZero() {
|
||||
return nil
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
forwardingChains := strings.Join(session.ForwardingChains, ",")
|
||||
_, err := store.DB.Exec(`
|
||||
INSERT INTO crypto_megolm_inbound_session
|
||||
(session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
|
||||
}
|
||||
_, err = store.DB.Exec(`
|
||||
INSERT INTO crypto_megolm_inbound_session (
|
||||
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
|
||||
ratchet_safety, received_at, max_age, max_messages, account_id
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (session_id, account_id) DO UPDATE
|
||||
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
|
||||
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains
|
||||
`, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, store.AccountID)
|
||||
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
|
||||
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
|
||||
max_age=excluded.max_age, max_messages=excluded.max_messages
|
||||
`,
|
||||
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains,
|
||||
ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages),
|
||||
store.AccountID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
var signingKey, forwardingChains, withheldCode sql.NullString
|
||||
var sessionBytes []byte
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
err := store.DB.QueryRow(`
|
||||
SELECT signing_key, session, forwarding_chains, withheld_code
|
||||
SELECT signing_key, session, forwarding_chains, withheld_code, ratchet_safety, received_at, max_age, max_messages
|
||||
FROM crypto_megolm_inbound_session
|
||||
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
|
||||
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
|
||||
roomID, senderKey, sessionID, store.AccountID,
|
||||
).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode)
|
||||
).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
|
|
@ -295,18 +323,38 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
|
|||
if forwardingChains.String != "" {
|
||||
chains = strings.Split(forwardingChains.String, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: senderKey,
|
||||
RoomID: roomID,
|
||||
ForwardingChains: chains,
|
||||
RatchetSafety: rs,
|
||||
ReceivedAt: receivedAt.Time,
|
||||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID) error {
|
||||
_, err := store.DB.Exec(`
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason='Session redacted', session=NULL, forwarding_chains=NULL, ratchet_safety=NULL
|
||||
WHERE session_id=$2 AND account_id=$3
|
||||
`, event.RoomKeyWithheldBeeperRedacted, sessionID, store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, account_id) VALUES ($1, $2, $3, $4, $5, $6)",
|
||||
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, store.AccountID)
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6. $7)",
|
||||
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -336,8 +384,10 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
|
|||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
var signingKey, senderKey, forwardingChains sql.NullString
|
||||
var sessionBytes []byte
|
||||
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains)
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -350,12 +400,23 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
|
|||
if forwardingChains.String != "" {
|
||||
chains = strings.Split(forwardingChains.String, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
}
|
||||
result = append(result, &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: id.Curve25519(senderKey.String),
|
||||
RoomID: roomID,
|
||||
ForwardingChains: chains,
|
||||
RatchetSafety: rs,
|
||||
ReceivedAt: receivedAt.Time,
|
||||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
})
|
||||
}
|
||||
return
|
||||
|
|
@ -363,7 +424,7 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
|
|||
|
||||
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(`
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages
|
||||
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
|
||||
roomID, store.AccountID,
|
||||
)
|
||||
|
|
@ -377,7 +438,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*Inbou
|
|||
|
||||
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(`
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages
|
||||
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
|
||||
store.AccountID,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
-- v0 -> v8: Latest revision
|
||||
-- v0 -> v10: Latest revision
|
||||
CREATE TABLE IF NOT EXISTS crypto_account (
|
||||
account_id TEXT PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
|
|
@ -52,6 +52,10 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
|
|||
forwarding_chains bytea,
|
||||
withheld_code TEXT,
|
||||
withheld_reason TEXT,
|
||||
ratchet_safety jsonb,
|
||||
received_at timestamp,
|
||||
max_age BIGINT,
|
||||
max_messages INTEGER,
|
||||
PRIMARY KEY (account_id, session_id)
|
||||
);
|
||||
|
||||
|
|
|
|||
5
crypto/sql_store_upgrade/10-mark-ratchetable-keys.sql
Normal file
5
crypto/sql_store_upgrade/10-mark-ratchetable-keys.sql
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
-- v10: Add flag for megolm sessions to mark them as safe to delete
|
||||
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb;
|
||||
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp;
|
||||
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT;
|
||||
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER;
|
||||
|
|
@ -54,6 +54,8 @@ type Store interface {
|
|||
// (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the
|
||||
// ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details.
|
||||
GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
|
||||
// RedactGroupSession removes the session data for the given inbound Megolm session from the store.
|
||||
RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID) error
|
||||
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
|
||||
PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error
|
||||
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
|
||||
|
|
@ -262,6 +264,14 @@ func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey,
|
|||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) error {
|
||||
gs.lock.Lock()
|
||||
delete(gs.getGroupSessions(roomID, senderKey), sessionID)
|
||||
err := gs.save()
|
||||
gs.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent {
|
||||
room, ok := gs.WithheldGroupSessions[roomID]
|
||||
if !ok {
|
||||
|
|
|
|||
|
|
@ -49,3 +49,9 @@ type BeeperRetryMetadata struct {
|
|||
RetryCount int `json:"retry_count"`
|
||||
// last_retry is also present, but not used by bridges
|
||||
}
|
||||
|
||||
type BeeperRoomKeyAckEventContent struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
SessionID id.SessionID `json:"session_id"`
|
||||
FirstMessageIndex int `json:"first_message_index"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,6 +80,8 @@ var TypeMap = map[Type]reflect.Type{
|
|||
|
||||
ToDeviceOrgMatrixRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}),
|
||||
|
||||
ToDeviceBeeperRoomKeyAck: reflect.TypeOf(BeeperRoomKeyAckEventContent{}),
|
||||
|
||||
CallInvite: reflect.TypeOf(CallInviteEventContent{}),
|
||||
CallCandidates: reflect.TypeOf(CallCandidatesEventContent{}),
|
||||
CallAnswer: reflect.TypeOf(CallAnswerEventContent{}),
|
||||
|
|
|
|||
|
|
@ -135,6 +135,8 @@ const (
|
|||
RoomKeyWithheldUnauthorized RoomKeyWithheldCode = "m.unauthorised"
|
||||
RoomKeyWithheldUnavailable RoomKeyWithheldCode = "m.unavailable"
|
||||
RoomKeyWithheldNoOlmSession RoomKeyWithheldCode = "m.no_olm"
|
||||
|
||||
RoomKeyWithheldBeeperRedacted RoomKeyWithheldCode = "com.beeper.redacted"
|
||||
)
|
||||
|
||||
type RoomKeyWithheldEventContent struct {
|
||||
|
|
|
|||
|
|
@ -124,7 +124,8 @@ func (et *Type) GuessClass() TypeClass {
|
|||
CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
|
||||
CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type:
|
||||
return MessageEventType
|
||||
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type:
|
||||
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type,
|
||||
ToDeviceBeeperRoomKeyAck.Type:
|
||||
return ToDeviceEventType
|
||||
default:
|
||||
return UnknownEventType
|
||||
|
|
@ -253,4 +254,6 @@ var (
|
|||
ToDeviceVerificationCancel = Type{"m.key.verification.cancel", ToDeviceEventType}
|
||||
|
||||
ToDeviceOrgMatrixRoomKeyWithheld = Type{"org.matrix.room_key.withheld", ToDeviceEventType}
|
||||
|
||||
ToDeviceBeeperRoomKeyAck = Type{"com.beeper.room_key.ack", ToDeviceEventType}
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue